#!/usr/bin/env python3
# Copyright 2019 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title :utils/gan_helpers.py
# @author :ch
# @contact :henningc@ethz.ch
# @created :12/17/2019
# @version :1.0
# @python_version :3.6.9
"""
Helper functions for training Generative Adversarial Networks
-------------------------------------------------------------
A collection of general-purpose helper functions for GAN training, e.g.,
several common GAN losses.
"""
import torch
import torch.nn.functional as F


def dis_loss(logit_real, logit_fake, loss_choice):
    """Compute the loss for the discriminator.

    Note, only the discriminator weights should be updated using this loss.

    Args:
        logit_real: Outputs of the discriminator after seeing real samples.

            .. note::
                We assume a linear output layer.
        logit_fake: Outputs of the discriminator after seeing fake samples.

            .. note::
                We assume a linear output layer.
        loss_choice (int): Define which loss function is used to train the
            GAN. Note, the choice of loss function also influences how the
            output of the discriminator network is reinterpreted or squashed
            (either into ``[0, 1]`` or kept as an arbitrary real number).

            The following choices are available:

            - ``0``: Vanilla GAN (Goodfellow et al., 2014). Non-saturating
              loss version. Note, we additionally apply one-sided label
              smoothing for this loss.
            - ``1``: Traditional LSGAN (Mao et al., 2018). See eq. 14 of
              the paper. This loss corresponds to the parameter choice
              :math:`a=0`, :math:`b=1` and :math:`c=1`.
            - ``2``: Pearson Chi^2 LSGAN (Mao et al., 2018). See eq. 13.
              Parameter choice: :math:`a=-1`, :math:`b=1` and :math:`c=0`.
            - ``3``: Wasserstein GAN (Arjovsky et al., 2017).

    Returns:
        The discriminator loss.
    """
    if loss_choice == 0: # Vanilla GAN
        # We use the binary cross-entropy.
        # Note, we use one-sided label smoothing.
        fake = torch.sigmoid(logit_fake)
        real = torch.sigmoid(logit_real)
        r_loss = F.binary_cross_entropy(real, 0.9 * torch.ones_like(real))
        f_loss = F.binary_cross_entropy(fake, torch.zeros_like(fake))
    elif loss_choice == 1: # Traditional LSGAN
        r_loss = F.mse_loss(logit_real, torch.ones_like(logit_real))
        f_loss = F.mse_loss(logit_fake, torch.zeros_like(logit_fake))
    elif loss_choice == 2: # Pearson Chi^2 LSGAN
        r_loss = F.mse_loss(logit_real, torch.ones_like(logit_real))
        f_loss = F.mse_loss(logit_fake, -torch.ones_like(logit_fake))
    else: # WGAN
        r_loss = -logit_real.mean()
        f_loss = logit_fake.mean()

    return r_loss + f_loss
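
# A minimal usage sketch for `dis_loss` (illustrative only; `dis`, `real`,
# `fake` and `opt_dis` are hypothetical placeholders for a discriminator
# network, sample batches and an optimizer, not part of this module):
#
#     logit_real = dis(real)
#     logit_fake = dis(fake.detach())  # Detach: only `dis` receives grads.
#     dis_loss(logit_real, logit_fake, loss_choice=0).backward()
#     opt_dis.step()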


def gen_loss(logit_fake, loss_choice):
    """Compute the loss for the generator.

    Args:
        (....): See docstring of function :func:`dis_loss`.

    Returns:
        The generator loss.
    """
    if loss_choice == 0: # Vanilla GAN
        # We use the -log(D) trick.
        fake = torch.sigmoid(logit_fake)
        return F.binary_cross_entropy(fake, torch.ones_like(fake))
    elif loss_choice == 1: # Traditional LSGAN
        return F.mse_loss(logit_fake, torch.ones_like(logit_fake))
    elif loss_choice == 2: # Pearson Chi^2 LSGAN
        return F.mse_loss(logit_fake, torch.zeros_like(logit_fake))
    else: # WGAN
        return -logit_fake.mean()
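
# A minimal usage sketch for `gen_loss`, mirroring the one above (again,
# `dis`, `gen`, `noise` and `opt_gen` are hypothetical placeholders):
#
#     logit_fake = dis(gen(noise))  # No detach; grads must reach `gen`.
#     gen_loss(logit_fake, loss_choice=0).backward()
#     opt_gen.step()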


def accuracy(logit_real, logit_fake, loss_choice):
    """Compute the accuracy of the discriminator.

    It is computed based on the assumption that values greater than a
    threshold are classified as real.

    Note, the accuracy measure is only well defined for the Vanilla GAN.
    However, we generalize the concept of accuracy to the other GAN
    formulations by considering their preferred value ranges and using the
    following thresholds:

    - ``0.5`` for Vanilla GAN and Traditional LSGAN
    - ``0`` for Pearson Chi^2 LSGAN and WGAN

    Args:
        (....): See docstring of function :func:`dis_loss`.

    Returns:
        The relative accuracy of the discriminator.
    """
    T = 0.5 if loss_choice < 2 else 0.
    # Note, logit values above 0 are mapped to values above 0.5 by the
    # sigmoid. Therefore, we use the threshold 0 in logit space whenever the
    # logits are meant to be passed through a sigmoid.
    T = 0. if loss_choice == 0 else T

    n_correct = (logit_real > T).float().sum() + \
        (logit_fake <= T).float().sum()
    return n_correct / (logit_real.numel() + logit_fake.numel())
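
# A small worked example of the thresholding above (made-up values): for the
# vanilla GAN (``loss_choice=0``) the threshold lives in logit space, so
#
#     accuracy(torch.tensor([2.3, -0.4]), torch.tensor([-1.2, 0.7]), 0)
#
# counts 2.3 (real logit > 0) and -1.2 (fake logit <= 0) as correct and
# returns 0.5.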


def concat_mean_stats(inputs):
    """Add mean statistics to the discriminator input.

    GANs often run into mode collapse, since the discriminator sees every
    sample in isolation, i.e., it cannot detect whether all samples in a
    batch look alike.

    A simple way to give the discriminator access to batch statistics is to
    concatenate the mean (across the batch dimension) of all discriminator
    samples to each individual sample.

    Args:
        inputs: The input batch to the discriminator.

    Returns:
        The modified input batch.
    """
    # Compute the mean sample and broadcast it to the batch size.
    stats = torch.mean(inputs, 0, keepdim=True)
    stats = stats.expand(inputs.size())
    # Concatenate the mean statistics along the feature/channel dimension.
    return torch.cat([stats, inputs], dim=1)
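
# A short shape example (assuming 2D inputs; the concatenation along
# ``dim=1`` acts on the channel dimension for 4D image batches):
#
#     x = torch.randn(32, 128)   # Batch of 32 feature vectors.
#     y = concat_mean_stats(x)   # -> shape [32, 256]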


if __name__ == '__main__':
    pass