GAN for Comic Faces Paired Generation
Model Overview
This model implements a conditional Generative Adversarial Network (GAN) with a UNet generator and a PatchGAN discriminator for paired image-to-image translation, trained on a synthetic dataset of comic faces. Given the first image of a pair, the generator produces the corresponding target image (e.g., a photo-to-cartoon or cartoon-to-photo transformation), and the discriminator judges whether an (input, output) pair looks real.
- Dataset: Comic Faces Paired Synthetic Dataset
- Batch Size: 32
- Input Shape: (3, 256, 256) (RGB Images)
- Output Shape: (3, 256, 256)
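The shapes above are channels-first, as PyTorch expects, whereas images loaded from disk are usually height x width x channels. A minimal NumPy sketch of building one batch, assuming 8-bit RGB images already at 256x256 and scaled to [0, 1] (the card does not specify the preprocessing or normalization actually used; `to_model_input` is a hypothetical helper):

```python
import numpy as np


def to_model_input(images):
    """Stack HxWx3 uint8 RGB images into an (N, 3, H, W) float32 batch.

    Assumption: pixels are scaled to [0, 1]; the card does not state the
    normalization used in training.
    """
    batch = np.stack(images).astype(np.float32) / 255.0  # (N, H, W, 3)
    return batch.transpose(0, 3, 1, 2)                   # (N, 3, H, W)


# 32 blank 256x256 images standing in for one batch of input images
images = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(32)]
batch = to_model_input(images)
print(batch.shape)  # (32, 3, 256, 256)
```

The target image of each pair would be converted the same way.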
Model Architecture
Generator: UNet
The generator uses a UNet, an encoder-decoder architecture commonly used for image-to-image translation. Skip connections pass the feature maps of each encoder level to the matching decoder level, which helps preserve fine spatial detail in the full-resolution output. The architecture consists of the following parts:
- Encoder Path (Contracting Path): four DoubleConv blocks (64, 128, 256, and 512 channels) extract features. A MaxPool2d layer (kernel size 2, stride 2) after each block halves the spatial dimensions.
- Bottleneck: the deepest DoubleConv block (1024 channels) processes the smallest feature map, 16x16 for a 256x256 input.
- Decoder Path (Expanding Path): bilinear Upsample layers (scale factor 2) double the spatial dimensions step by step. After each upsampling, the matching encoder feature map is concatenated along the channel dimension (the skip connection), and a DoubleConv block refines the result (for example, 1024 + 512 input channels are reduced to 512).
- Final Convolution: a 1x1 convolution maps the 64 decoder channels to the number of output channels (3 for RGB).
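The channel and spatial bookkeeping of this encoder-decoder can be checked with a small pure-Python sketch. It hard-codes the layer widths from the UNet code in this card and the factor-2 pooling and upsampling; it does not build the network, it only traces shapes for a square input (`unet_shape_trace` is a hypothetical helper).

```python
def unet_shape_trace(size=256, out_channels=3, widths=(64, 128, 256, 512, 1024)):
    """List (layer, channels, height) for the UNet described above (square inputs)."""
    trace, skips = [], []
    h = size
    for i, w in enumerate(widths):
        if i > 0:
            h //= 2                    # MaxPool2d(kernel_size=2, stride=2)
        trace.append((f"down_conv{i + 1}", w, h))  # DoubleConv keeps H and W
        skips.append((w, h))
    c = skips.pop()[0]                 # bottleneck output (1024 channels) is not a skip
    for j, (skip_c, skip_h) in enumerate(reversed(skips)):
        h *= 2                         # Upsample(scale_factor=2)
        if h != skip_h:
            raise ValueError(f"cannot concatenate {h}x{h} with {skip_h}x{skip_h} skip")
        concat_c = c + skip_c          # torch.cat([skip, x], dim=1)
        c = skip_c                     # DoubleConv(concat_c, skip_c)
        trace.append((f"up_conv{j + 1} (in {concat_c})", c, h))
    trace.append(("final_conv", out_channels, h))  # 1x1 conv, size unchanged
    return trace


for layer, channels, height in unet_shape_trace():
    print(f"{layer:<20} {channels:>5} x {height} x {height}")
```

Calling `unet_shape_trace(250)` raises `ValueError`, mirroring the `torch.cat` failure the real network would hit: with four pooling steps, the input height and width should be divisible by 16.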
Discriminator: PatchGANDiscriminator
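The discriminator uses a PatchGAN architecture, which classifies overlapping patches of the image as real or fake instead of scoring the whole image at once. Its input is an image pair concatenated along the channel dimension: the 3-channel input image plus the 3-channel real or generated output, for 6 channels in total. Three stride-2 and two stride-1 Conv2d layers reduce the spatial dimensions, with LeakyReLU (slope 0.2) activations after every convolution except the last and InstanceNorm2d after every convolution except the first and the last. The final convolution outputs a single-channel map of raw scores (logits), one per patch; no sigmoid is applied, so each value becomes a real/fake probability only after a sigmoid.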
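The per-patch behaviour follows from convolution arithmetic. The sketch below takes the kernel sizes, strides, and padding from the PatchGANDiscriminator code in this card and computes, for a 256x256 input pair, the side length of the output grid and the receptive field of each output value (the helper names are hypothetical).

```python
# (kernel_size, stride, padding) of each Conv2d in PatchGANDiscriminator
LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]


def conv_out(size, kernel, stride, padding):
    """Output length of a Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def patch_grid_size(size, layers=LAYERS):
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
    return size


def receptive_field(layers=LAYERS):
    """Side length, in input pixels, seen by one output logit."""
    rf = 1
    for kernel, stride, _ in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf


print(patch_grid_size(256), receptive_field())  # 30 70
```

So for a 256x256 pair the discriminator returns a 30x30 grid of logits, each judging an overlapping 70x70 region of the 6-channel input.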
Generator Code (UNet):
```python
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

# Note: DoubleConv is used below, but its definition is not included in this card.


class UNet(nn.Module, PyTorchModelHubMixin):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Contracting path (encoder)
        self.down_conv1 = DoubleConv(in_channels, 64)
        self.down_conv2 = DoubleConv(64, 128)
        self.down_conv3 = DoubleConv(128, 256)
        self.down_conv4 = DoubleConv(256, 512)
        self.down_conv5 = DoubleConv(512, 1024)  # bottleneck
        # Downsampling: halves height and width
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Upsampling: doubles height and width (bilinear interpolation, no learned weights)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        # Expanding path (decoder); input channels = upsampled channels + skip channels
        self.up_conv1 = DoubleConv(1024 + 512, 512)
        self.up_conv2 = DoubleConv(512 + 256, 256)
        self.up_conv3 = DoubleConv(256 + 128, 128)
        self.up_conv4 = DoubleConv(128 + 64, 64)
        # Final 1x1 convolution to get the desired number of output channels
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder: keep each level's output for the skip connections
        x1 = self.down_conv1(x)                 # 64 channels, H x W
        x2 = self.down_conv2(self.maxpool(x1))  # 128 channels, H/2
        x3 = self.down_conv3(self.maxpool(x2))  # 256 channels, H/4
        x4 = self.down_conv4(self.maxpool(x3))  # 512 channels, H/8
        x5 = self.down_conv5(self.maxpool(x4))  # 1024 channels, H/16

        # Decoder: upsample, concatenate the matching skip, refine
        x = self.upsample(x5)
        x = torch.cat([x4, x], dim=1)
        x = self.up_conv1(x)
        x = self.upsample(x)
        x = torch.cat([x3, x], dim=1)
        x = self.up_conv2(x)
        x = self.upsample(x)
        x = torch.cat([x2, x], dim=1)
        x = self.up_conv3(x)
        x = self.upsample(x)
        x = torch.cat([x1, x], dim=1)
        x = self.up_conv4(x)
        return self.final_conv(x)
```
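The card does not include the DoubleConv block the UNet depends on. A common choice in UNet implementations, assumed here and not confirmed by the source, is two 3x3 convolutions, each followed by BatchNorm2d and ReLU. Padding 1 keeps the spatial size unchanged, which the skip-connection concatenations require because the decoder upsamples by exactly 2 without cropping.

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Hypothetical DoubleConv: (3x3 conv -> BatchNorm -> ReLU) applied twice.

    padding=1 keeps height and width unchanged, so encoder and decoder
    feature maps line up for torch.cat in the UNet.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


x = torch.randn(2, 3, 64, 64)
print(DoubleConv(3, 64)(x).shape)  # torch.Size([2, 64, 64, 64])
```

The convolutions omit a bias because the following BatchNorm2d has its own learnable shift; this is a design choice of the sketch, not something the card specifies.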
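Discriminator Code (PatchGAN):
```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class PatchGANDiscriminator(nn.Module, PyTorchModelHubMixin):
    def __init__(self, in_channels=6):
        # in_channels=6: 3-channel input image concatenated with a 3-channel real or generated output
        super().__init__()
        self.layers = nn.Sequential(
            # First layer: no normalization
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),  # 256 -> 128
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),          # 128 -> 64
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),         # 64 -> 32
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),         # 32 -> 31
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # One logit per patch; no sigmoid is applied
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),           # 31 -> 30
        )

    def forward(self, x):
        # x: (N, 6, H, W) image pair -> (N, 1, H', W') map of patch logits
        return self.layers(x)
```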
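Because the final Conv2d has no sigmoid, the discriminator's 30x30 output map holds logits. A minimal sketch, assuming a pix2pix-style binary cross-entropy objective (the card does not describe the training losses), shows how such a map could be scored: every patch of a real pair is labelled 1 and every patch of a generated pair 0. Random tensors stand in for the discriminator outputs so the example runs on its own; the names `disc`, `input_img`, `real_target`, and `fake_target` in the comments are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-ins for disc(torch.cat([input_img, real_target], dim=1)) and
# disc(torch.cat([input_img, fake_target], dim=1)) on a batch of 32 pairs.
real_logits = torch.randn(32, 1, 30, 30)
fake_logits = torch.randn(32, 1, 30, 30)

bce = nn.BCEWithLogitsLoss()  # applies the sigmoid internally, numerically stable
d_loss = (bce(real_logits, torch.ones_like(real_logits))
          + bce(fake_logits, torch.zeros_like(fake_logits)))

# Per-patch probability that each region of the real pair looks real
patch_probs = torch.sigmoid(real_logits)
print(d_loss.item(), patch_probs.shape)
```

When updating the generator under this assumed objective, the fake logits would instead be compared with a tensor of ones.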