#!/usr/bin/env python3
# Copyright 2019 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title :utils/batchnorm_layer.py
# @author :ch
# @contact :henningc@ethz.ch
# @created :09/02/2019
# @version :1.0
# @python_version :3.6.8
"""
Batch Normalization
-------------------
Implementation of a hypernet compatible batchnorm layer.
The joint use of batch-normalization and hypernetworks is not straight forward,
mainly due to the statistics accumulated by the batch-norm operation which
expect the weights of the main network to only change slowly. If a hypernetwork
replaces the whole set of weights, the statistics previously estimated by the
batch-norm layer might be completely off.
To circumvent this problem, we provide multiple solutions:
- In a continual learning setting with one set of weights per task, we can
simply estimate and store statistics per task (hence, the batch-norm
operation has to be conditioned on the task).
- The statistics are distilled into the hypernetwork. This would require
the addition of an extra loss term.
- The statistics can be treated as parameters that are outputted by the
hypernetwork. In this case, nothing enforces that these "statistics"
behave similar to statistics that would result from a running estimate
(hence, the resulting operation might have nothing in common with batch-
norm).
- Always use the statistics estimated on the current batch.
Note, we also provide the option of turning off the statistics, in which case
the statistics will be set to zero mean and unit variance. This is helpful when
interpreting batch-normalization as a general form of gain modulation (i.e.,
just applying a shift and scale to neural activities).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from warnings import warn
class BatchNormLayer(nn.Module):
    r"""Hypernetwork-compatible batch-normalization layer.

    Note, batch normalization performs the following operation

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \
            \gamma + \beta

    This class allows to deviate from this standard implementation in order to
    provide the flexibility required when using hypernetworks. Therefore, we
    slightly change the notation to

    .. math::

        y = \frac{x - m_{\text{stats}}^{(t)}}{\sqrt{v_{\text{stats}}^{(t)} + \
            \epsilon}} * \gamma^{(t)} + \beta^{(t)}

    We use this notation to highlight that the running statistics
    :math:`m_{\text{stats}}^{(t)}` and :math:`v_{\text{stats}}^{(t)}` are not
    necessarily estimates resulting from mean and variance computation but might
    be learned parameters (e.g., the outputs of a hypernetwork).

    We additionally use the superscript :math:`(t)` to denote that the gain
    :math:`\gamma`, offset :math:`\beta` and statistics may be dynamically
    selected based on some external context information.

    This class provides the possibility to checkpoint statistics
    :math:`m_{\text{stats}}^{(t)}` and :math:`v_{\text{stats}}^{(t)}`, but
    **not** gains and offsets.

    .. note::
        If context-dependent gains :math:`\gamma^{(t)}` and offsets
        :math:`\beta^{(t)}` are required, then they have to be maintained
        externally, e.g., via a task-conditioned hypernetwork (see
        `this paper`_ for an example) and passed to the :meth:`forward` method.

        .. _this paper: https://arxiv.org/abs/1906.00695
    """

    def __init__(self, num_features, momentum=0.1, affine=True,
                 track_running_stats=True, frozen_stats=False,
                 learnable_stats=False):
        r"""
        Args:
            num_features: See argument ``num_features``, for instance, of class
                :class:`torch.nn.BatchNorm1d`.
            momentum: See argument ``momentum`` of class
                :class:`torch.nn.BatchNorm1d`. The value ``None`` is not
                supported.
            affine: See argument ``affine`` of class
                :class:`torch.nn.BatchNorm1d`. If ``True``, the layer maintains
                an internal gain :math:`\gamma` (initialized to one) and offset
                :math:`\beta` (initialized to zero), which are exposed via
                :attr:`weights`.

                If ``False``, no gain and offset are maintained internally
                (:attr:`weights` is ``None``) and both have to be passed to the
                :meth:`forward` method.

                Note, if ``learnable_stats`` is :code:`False`, then setting
                ``affine`` to :code:`False` results in no learnable weights for
                this layer (running stats might still be updated, but not via
                gradient descent).
            track_running_stats: See argument ``track_running_stats`` of class
                :class:`torch.nn.BatchNorm1d`.
            frozen_stats: If ``True``, the layer statistics are frozen at their
                initial values of zero mean and unit variance, i.e., layer
                activity will not be whitened.

                Note, this option requires ``track_running_stats`` to be set to
                ``False``.
            learnable_stats: If ``True``, the layer statistics are initialized
                as learnable parameters (:code:`requires_grad=True`).

                Note, these extra parameters will be maintained internally and
                not added to the :attr:`weights`. Statistics can always be
                maintained externally and passed to the :meth:`forward` method.

                Note, this option requires ``track_running_stats`` to be set to
                ``False``.

                .. note::
                    This option is not implemented yet.

        Raises:
            NotImplementedError: If ``learnable_stats`` is ``True`` or if
                ``momentum`` is ``None``.
            ValueError: If ``frozen_stats`` is combined with
                ``track_running_stats``.
        """
        super(BatchNormLayer, self).__init__()

        if learnable_stats:
            # FIXME We need our custom stats computation for this.
            # The running stats updated by `torch.nn.functional.batch_norm` do
            # not allow backpropagation.
            # See here on how they are computed:
            # https://github.com/pytorch/pytorch/blob/96fe2b4ecbbd02143d95f467655a2d697282ac32/aten/src/ATen/native/Normalization.cpp#L137
            raise NotImplementedError('Option "learnable_stats" has not been ' +
                                      'implemented yet!')

        if momentum is None:
            # If one wants to implement this, then please note that the
            # attribute `num_batches_tracked` has to be added. Also, note the
            # extra code for computing the momentum value in the forward method
            # of class `_BatchNorm`:
            # https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#_BatchNorm
            raise NotImplementedError('This reimplementation of PyTorch its ' +
                                      'batchnorm layer does not support ' +
                                      'setting "momentum" to None.')

        if learnable_stats and track_running_stats:
            raise ValueError('Option "track_running_stats" must be set to ' +
                             'False when enabling "learnable_stats".')

        if frozen_stats and track_running_stats:
            raise ValueError('Option "track_running_stats" must be set to ' +
                             'False when enabling "frozen_stats".')

        self._num_features = num_features
        self._momentum = momentum
        self._affine = affine
        self._track_running_stats = track_running_stats
        self._frozen_stats = frozen_stats
        self._learnable_stats = learnable_stats

        # Number of checkpointed statistics. Kept as a buffer such that it is
        # stored alongside the statistics themselves.
        self.register_buffer('_num_stats', torch.tensor(0, dtype=torch.long))

        self._weights = nn.ParameterList()
        self._param_shapes = [[num_features], [num_features]]

        if affine:
            # Gamma (gain), initialized to one.
            self.register_parameter('scale', nn.Parameter(
                torch.ones(num_features), requires_grad=True))
            # Beta (offset), initialized to zero.
            self.register_parameter('bias', nn.Parameter(
                torch.zeros(num_features), requires_grad=True))

            self._weights.append(self.scale)
            self._weights.append(self.bias)
        elif not learnable_stats:
            # Neither gains/offsets nor stats are learnable internally.
            self._weights = None

        if learnable_stats:
            # Don't forget to add the new params to `self._weights`.
            # Don't forget to add shapes to `self._param_shapes`.
            raise NotImplementedError()
        elif track_running_stats or frozen_stats:
            # Note, in case of frozen stats, we just don't update the stats
            # initialized here later on.
            self.checkpoint_stats()
        else:
            # No internal statistics; register empty placeholders.
            mname, vname = self._stats_names(0)
            self.register_buffer(mname, None)
            self.register_buffer(vname, None)

    @property
    def weights(self):
        """A list of all internal weights of this layer. If all weights are
        assumed to be generated externally, then this attribute will be
        ``None``.

        :type: torch.nn.ParameterList or None
        """
        return self._weights

    @property
    def param_shapes(self):
        """A list of list of integers. Each list represents the shape of a
        parameter tensor.

        Note, this attribute is independent of the attribute :attr:`weights`,
        it always comprises the shapes of all weight tensors as if the network
        would be stand-alone (i.e., no weights being passed to the
        :meth:`forward` method).

        Note, unless ``learnable_stats`` is enabled, the layer statistics are
        not considered here.

        :type: list
        """
        return self._param_shapes

    @property
    def hyper_shapes(self):
        r"""A list of list of integers. Each list represents the shape of a
        weight tensor that can be passed to the :meth:`forward` method. If all
        weights are maintained internally, then this attribute will be ``None``.

        Specifically, this attribute is controlled by the argument ``affine``.
        If ``affine`` is ``True``, this attribute will be ``None``. Otherwise
        this attribute contains the shape of :math:`\gamma` and :math:`\beta`.

        .. note::
            This attribute is not implemented yet; accessing it always raises
            a :class:`NotImplementedError`.

        :type: list or None
        """
        # FIXME not implemented attribute. Do we even need the attribute, given
        # that all components are individually passed to the forward method?
        raise NotImplementedError('Not implemented yet!')

    @property
    def num_stats(self):
        r"""The number :math:`T` of internally managed statistics
        :math:`\{(m_{\text{stats}}^{(1)}, v_{\text{stats}}^{(1)}), \dots, \
        (m_{\text{stats}}^{(T)}, v_{\text{stats}}^{(T)}) \}`. This number is
        incremented every time the method :meth:`checkpoint_stats` is called.

        Note, the counter is returned as it is stored internally, i.e., as a
        scalar buffer of dtype ``torch.long``. Use ``int(layer.num_stats)`` to
        obtain a Python integer.

        :type: torch.Tensor
        """
        return self._num_stats

    def forward(self, inputs, running_mean=None, running_var=None, weight=None,
                bias=None, stats_id=None):
        r"""Apply batch normalization to given layer activations.

        Based on the state of this module (attribute :attr:`training`), the
        configuration of this layer and the parameters currently passed, the
        behavior of this function will be different.

        The core of this method still relies on the function
        :func:`torch.nn.functional.batch_norm`. In the following we list the
        different behaviors of this method based on the context.

        **In training mode:**

        We first consider the case that this module is in training mode, i.e.,
        :meth:`torch.nn.Module.train` has been called.

        Usually, during training, the running statistics are not used when
        computing the output, instead the statistics computed on the current
        batch are used (denoted by *use batch stats* in the table below).
        However, the running statistics are typically updated during training
        (denoted by *update running stats* in the table below).

        The above described scenario would correspond to passing batch
        statistics to the function :func:`torch.nn.functional.batch_norm` and
        setting the parameter ``training`` to ``True``.

        +----------------------+---------------------+-------------------------+
        | **training mode**    | **use batch stats** | **update running stats**|
        +----------------------+---------------------+-------------------------+
        | given stats          | Yes                 | Yes                     |
        +----------------------+---------------------+-------------------------+
        | track running stats  | Yes                 | Yes                     |
        +----------------------+---------------------+-------------------------+
        | frozen stats         | No                  | No                      |
        +----------------------+---------------------+-------------------------+
        | learnable stats      | Yes                 | Yes [1]_                |
        +----------------------+---------------------+-------------------------+
        |no track running stats| Yes                 | No                      |
        +----------------------+---------------------+-------------------------+

        The meaning of each row in this table is as follows:

            - **given stats**: External stats are provided via the parameters
              ``running_mean`` and ``running_var``.
            - **track running stats**: If ``track_running_stats`` was set to
              ``True`` in the constructor and no stats were given.
            - **frozen stats**: If ``frozen_stats`` was set to ``True`` in the
              constructor and no stats were given.
            - **learnable stats**: If ``learnable_stats`` was set to ``True`` in
              the constructor and no stats were given (not implemented yet).
            - **no track running stats**: If none of the above options apply,
              then the statistics will always be computed from the current batch
              (also in eval mode).

        .. note::
            If provided, running stats specified via ``running_mean`` and
            ``running_var`` always have priority.

        .. [1] We use a custom implementation to update the running statistics,
           that is compatible with backpropagation.

        **In evaluation mode:**

        We now consider the case that this module is in evaluation mode, i.e.,
        :meth:`torch.nn.Module.eval` has been called.

        Here is the same table as above just for the evaluation mode.

        +----------------------+---------------------+-------------------------+
        | **evaluation mode**  | **use batch stats** | **update running stats**|
        +----------------------+---------------------+-------------------------+
        | track running stats  | No                  | No                      |
        +----------------------+---------------------+-------------------------+
        | frozen stats         | No                  | No                      |
        +----------------------+---------------------+-------------------------+
        | learnable stats      | No                  | No                      |
        +----------------------+---------------------+-------------------------+
        | given stats          | No                  | No                      |
        +----------------------+---------------------+-------------------------+
        |no track running stats| Yes                 | No                      |
        +----------------------+---------------------+-------------------------+

        Args:
            inputs: The inputs to the batchnorm layer.
            running_mean (optional): Running mean stats
                :math:`m_{\text{stats}}`. This option has priority, i.e., any
                internally maintained statistics are ignored if given.

                .. note::
                    If specified, then ``running_var`` also has to be specified.
            running_var (optional): Similar to option ``running_mean``, but for
                the running variance stats :math:`v_{\text{stats}}`

                .. note::
                    If specified, then ``running_mean`` also has to be
                    specified.
            weight (optional): The gain factors :math:`\gamma`. If given, any
                internal gains are ignored. This argument is mandatory if
                option ``affine`` was set to ``False`` in the constructor.
            bias (optional): The behavior of this option is similar to option
                ``weight``, except that this option represents the offsets
                :math:`\beta`.
            stats_id: This argument is optional except if multiple running
                stats checkpoints exist (i.e., attribute :attr:`num_stats` is
                greater than 1) and no running stats have been provided to this
                method. If omitted, the first checkpoint (ID 0) is used.

                .. note::
                    This argument is ignored if running stats have been passed.

        Returns:
            The layer activation ``inputs`` after batch-norm has been applied.

        Raises:
            ValueError: If the layer is non-affine and ``weight`` or ``bias``
                is missing, or if ``stats_id`` is required but missing.
        """
        assert (running_mean is None) == (running_var is None), \
            'Arguments "running_mean" and "running_var" have to be either ' + \
            'both specified or both omitted.'

        if not self._affine and (weight is None or bias is None):
            raise ValueError('Layer was generated in non-affine mode. ' +
                             'Therefore, arguments "weight" and "bias" ' +
                             'may not be None.')

        # No gains/offsets given, but we have internal ones.
        if self._affine:
            if weight is None:
                weight = self.scale
            if bias is None:
                bias = self.bias

        stats_given = running_mean is not None

        if stats_given:
            if stats_id is not None:
                warn('Parameter "stats_id" is ignored since running stats ' +
                     'have been provided.')
        else:
            if stats_id is None and self.num_stats > 1:
                raise ValueError('Parameter "stats_id" is not defined but ' +
                                 'multiple running stats are available.')

            # Frozen statistics are stored internally just like tracked ones
            # (see constructor), so they have to be retrieved here as well.
            # Otherwise, `F.batch_norm` would be called in evaluation mode
            # without any statistics, which raises an error.
            if self._track_running_stats or self._frozen_stats:
                if stats_id is None:
                    stats_id = 0
                assert stats_id < self.num_stats, \
                    'Invalid "stats_id": no such stats checkpoint.'

                running_mean, running_var = self.get_stats(stats_id)

        momentum = self._momentum

        if stats_given or self._track_running_stats:
            # In training mode, batch stats are used and the provided/internal
            # running stats are updated in-place by `F.batch_norm`. In eval
            # mode, the running stats are used for normalization.
            return F.batch_norm(inputs, running_mean, running_var,
                                weight=weight, bias=bias,
                                training=self.training, momentum=momentum)

        if self._learnable_stats:
            raise NotImplementedError()

        if self._frozen_stats:
            # TODO Since the frozen stats are always zero mean and unit
            # variance, this call could be replaced by a plain scale and
            # shift of `inputs` to save computation.
            return F.batch_norm(inputs, running_mean, running_var,
                                weight=weight, bias=bias, training=False)

        assert not self._track_running_stats
        # Always compute statistics based on current batch.
        return F.batch_norm(inputs, None, None, weight=weight, bias=bias,
                            training=True, momentum=momentum)

    def checkpoint_stats(self, device=None):
        """Buffers for a new set of running stats will be registered.

        The new mean is initialized to zero and the new variance to one.
        Calling this function will also increment the attribute
        :attr:`num_stats`.

        Note, layers with frozen statistics may only hold a single set of
        statistics (created by the constructor).

        Args:
            device (optional): If not provided, the newly created statistics
                will either be moved to the device of the most recent statistics
                or to the default device (typically CPU) if no prior statistics
                exist.
        """
        assert self._track_running_stats or \
            (self._frozen_stats and self._num_stats == 0), \
            'Stats can only be checkpointed when tracking running stats.'

        if device is None and self.num_stats > 0:
            # Place new stats next to the most recent ones.
            mname_old, _ = self._stats_names(self._num_stats - 1)
            device = getattr(self, mname_old).device

        if self._learnable_stats:
            raise NotImplementedError()

        # Names have to be determined before the counter is incremented.
        mname, vname = self._stats_names(self._num_stats)
        self._num_stats += 1

        self.register_buffer(mname, torch.zeros(self._num_features,
                                                device=device))
        self.register_buffer(vname, torch.ones(self._num_features,
                                               device=device))

    def get_stats(self, stats_id=None):
        """Get a set of running statistics (means and variances).

        Args:
            stats_id (optional): ID of stats. If not provided, the most recent
                stats are returned.

        Returns:
            (tuple): Tuple containing:

            - **running_mean**
            - **running_var**
        """
        if stats_id is None:
            stats_id = self.num_stats - 1
        assert stats_id < self.num_stats, \
            'Invalid "stats_id": no such stats checkpoint.'

        mname, vname = self._stats_names(stats_id)

        running_mean = getattr(self, mname)
        running_var = getattr(self, vname)

        return running_mean, running_var

    def _stats_names(self, stats_id):
        """Get the buffer names for mean and variance statistics depending on
        the ``stats_id``, i.e., the ID of the stats checkpoint.

        Args:
            stats_id: ID of stats.

        Returns:
            (tuple): Tuple containing:

            - **mean_name**
            - **var_name**
        """
        mean_name = 'mean_%d' % stats_id
        var_name = 'var_%d' % stats_id

        return mean_name, var_name
if __name__ == '__main__':
    # This module only provides the batchnorm layer implementation; there is
    # nothing to execute when it is run as a script.
    pass