This subpackage contains common helper functions for a variety of problems (e.g., PyTorch checkpointing, special layers, computing diagonal Fisher matrices, …).
Implementation of a hypernet-compatible batchnorm layer.
The joint use of batch normalization and hypernetworks is not straightforward,
mainly because the statistics accumulated by the batch-norm operation assume
that the weights of the main network change only slowly. If a hypernetwork
replaces the whole set of weights, the statistics previously estimated by the
batch-norm layer might be completely off.
To circumvent this problem, we provide multiple solutions:
- In a continual learning setting with one set of weights per task, we can
  simply estimate and store statistics per task (hence, the batch-norm
  operation has to be conditioned on the task).
- The statistics are distilled into the hypernetwork. This would require
  the addition of an extra loss term.
- The statistics can be treated as parameters that are outputted by the
  hypernetwork. In this case, nothing enforces that these “statistics”
  behave similarly to statistics that would result from a running estimate
  (hence, the resulting operation might have nothing in common with batch
  norm).
- Always use the statistics estimated on the current batch.
Note, we also provide the option of turning off the statistics, in which case
the statistics will be set to zero mean and unit variance. This is helpful when
interpreting batch-normalization as a general form of gain modulation (i.e.,
just applying a shift and scale to neural activities).
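Note, batch normalization performs the following operation

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$$

This class allows deviating from this standard implementation in order to
provide the flexibility required when using hypernetworks. Therefore, we
slightly change the notation to

$$y = \frac{x - m_{\text{stats}}^{(t)}}{\sqrt{v_{\text{stats}}^{(t)} + \epsilon}} \cdot \gamma^{(t)} + \beta^{(t)}$$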
We use this notation to highlight that the running statistics
$m_{\text{stats}}^{(t)}$ and $v_{\text{stats}}^{(t)}$ are not
necessarily estimates resulting from mean and variance computation but might
be learned parameters (e.g., the outputs of a hypernetwork).
We additionally use the superscript $(t)$ to denote that the gain
$\gamma^{(t)}$, offset $\beta^{(t)}$ and statistics may be dynamically
selected based on some external context information.
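To make the notation concrete, here is a minimal pure-Python sketch of the operation above for a single sample, where the statistics $m_{\text{stats}}^{(t)}$ and $v_{\text{stats}}^{(t)}$ are looked up by a context id $t$. The function name batch_norm and the dictionary stats_per_context are hypothetical illustrations, not part of this class; the actual layer operates on PyTorch tensors via torch.nn.functional.batch_norm().

```python
import math


def batch_norm(x, mean, var, gain=None, offset=None, eps=1e-5):
    """Normalize one sample x (a list of features) with the given statistics.

    Computes y_i = (x_i - mean_i) / sqrt(var_i + eps) * gain_i + offset_i.
    If gain/offset are None, no scaling/shift is applied (cf. affine=False).
    """
    y = []
    for i, xi in enumerate(x):
        yi = (xi - mean[i]) / math.sqrt(var[i] + eps)
        if gain is not None:
            yi *= gain[i]
        if offset is not None:
            yi += offset[i]
        y.append(yi)
    return y


# Hypothetical statistics per context t (e.g., one set per task), stored as
# (m_stats, v_stats). They could equally be produced by a hypernetwork.
stats_per_context = {
    0: ([0.0, 1.0], [1.0, 4.0]),
    1: ([2.0, -1.0], [0.25, 1.0]),
}

x = [1.0, 3.0]
for t, (m_stats, v_stats) in stats_per_context.items():
    print(t, batch_norm(x, m_stats, v_stats, eps=0.0))
# 0 [1.0, 1.0]
# 1 [-2.0, 4.0]
```

Frozen statistics with zero mean and unit variance (and eps=0) leave the input unchanged, which matches the interpretation of batch normalization as pure gain modulation.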
This class provides the possibility to checkpoint statistics
$m_{\text{stats}}^{(t)}$ and $v_{\text{stats}}^{(t)}$, but
not gains and offsets.
Note
If context-dependent gains $\gamma^{(t)}$ and offsets $\beta^{(t)}$
are required, then they have to be maintained
externally, e.g., via a task-conditioned hypernetwork (see
this paper for an example) and passed to the forward() method.
Parameters:
num_features – See argument num_features, for instance, of class
torch.nn.BatchNorm1d.
See argument affine of class
torch.nn.BatchNorm1d. If set to False, the
input activity will simply be “whitened” according to the
applied layer statistics (except if gain $\gamma$ and
offset $\beta$ are passed to the forward() method).
Note, if learnable_stats is False, then setting
affine to False results in no learnable weights for
this layer (running stats might still be updated, but not via
gradient descent).
Note, even if this option is False, one may still pass a
gain and offset to the
forward() method.
track_running_stats – See argument track_running_stats of class
torch.nn.BatchNorm1d.
frozen_stats –
If True, the layer statistics are frozen at their
initial values of $m_{\text{stats}} = 0$ and $v_{\text{stats}} = 1$,
i.e., layer activity will not be whitened.
Note, this option requires track_running_stats to be set to
False.
learnable_stats –
If True, the layer statistics are initialized
as learnable parameters (requires_grad=True).
Note, these extra parameters will be maintained internally and
not added to the weights. Statistics can always be
maintained externally and passed to the forward() method.
Note, this option requires track_running_stats to be set to
False.
Buffers for a new set of running stats will be registered.
Calling this function will also increment the attribute
num_stats.
Parameters:
device (optional) – If not provided, the newly created statistics
will either be moved to the device of the most recent statistics
or to CPU if no prior statistics exist.
Apply batch normalization to given layer activations.
Based on the state of this module (attribute training), the
configuration of this layer and the parameters currently passed, this
method behaves differently.
The core of this method still relies on the function
torch.nn.functional.batch_norm(). In the following, we list the
different behaviors of this method based on the context.
In training mode:
We first consider the case that this module is in training mode, i.e.,
torch.nn.Module.train() has been called.
Usually, during training, the running statistics are not used when
computing the output; instead, the statistics computed on the current
batch are used (denoted by use batch stats in the table below).
However, the running statistics are typically updated during training
(denoted by update running stats in the table below).
The scenario described above corresponds to passing the running
statistics to the function torch.nn.functional.batch_norm() and
setting its parameter training to True.
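training mode | use batch stats | update running stats
track running stats | Yes | Yes
frozen stats | No | No
learnable stats | No | No
given stats | Yes | Yes
no track running stats | Yes | No
The meaning of each row in this table is as follows: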
given stats: External stats are provided via the parameters
running_mean and running_var.
track running stats: If track_running_stats was set to
True in the constructor and no stats were given.
frozen stats: If frozen_stats was set to True in the
constructor and no stats were given.
learnable stats: If learnable_stats was set to True in
the constructor and no stats were given.
no track running stats: If none of the above options apply,
then the statistics will always be computed from the current batch
(also in eval mode).
Note
If provided, running stats specified via running_mean and
running_var always have priority.
In evaluation mode:
We now consider the case that this module is in evaluation mode, i.e.,
torch.nn.Module.eval() has been called.
The corresponding table for the evaluation mode reads as follows.
evaluation mode | use batch stats | update running stats
track running stats | No | No
frozen stats | No | No
learnable stats | No | No
given stats | No | No
no track running stats | Yes | No
Parameters:
inputs – The inputs to the batchnorm layer.
running_mean (optional) –
Running mean stats $m_{\text{stats}}$. This option has priority, i.e., any
internally maintained statistics are ignored if given.
Note
If specified, then running_var also has to be specified.
running_var (optional) –
Similar to option running_mean, but for
the running variance stats $v_{\text{stats}}$.
Note
If specified, then running_mean also has to be
specified.
weight (optional) – The gain factors $\gamma$. If given, any
internal gains are ignored. If option affine was set to
False in the constructor and this option remains None,
then no gains are multiplied to the “whitened” inputs.
bias (optional) – The behavior of this option is similar to option
weight, except that this option represents the offsets
$\beta$.
stats_id –
This argument is optional except if multiple running
stats checkpoints exist (i.e., attribute num_stats is
greater than 1) and no running stats have been provided to this
method.
Note
This argument is ignored if running stats have been passed.
Returns:
The layer activation inputs after batch-norm has been applied.
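The behavior summarized in the tables above can be sketched in pure Python as follows. This is a hypothetical illustration (bn_forward, its mode strings and the dictionary layout of the statistics are made up here), not the implementation of this layer. It mimics what torch.nn.functional.batch_norm() does with training=True: normalize with the (biased) batch statistics and update the running estimates by an exponential moving average that uses the unbiased batch variance; momentum=0.1 is PyTorch's default. Gains and offsets are omitted.

```python
import math


def bn_forward(x_batch, training, mode="track", stats=None, given=None,
               momentum=0.1, eps=1e-5):
    """Select and (possibly) update statistics, then normalize.

    x_batch: list of N samples, each a list of F floats.
    mode:    'track', 'frozen', 'learnable' or 'none' (no running stats).
    stats:   internal statistics as {'mean': [...], 'var': [...]}.
    given:   external statistics in the same format; they have priority
             over the internal ones and are updated in place.
    """
    n, f = len(x_batch), len(x_batch[0])
    b_mean = [sum(s[j] for s in x_batch) / n for j in range(f)]
    # The biased batch variance is used for normalization.
    b_var = [sum((s[j] - b_mean[j]) ** 2 for s in x_batch) / n
             for j in range(f)]

    if given is not None:
        stats, use_batch, update = given, training, training
    elif mode == "track":
        use_batch, update = training, training
    elif mode in ("frozen", "learnable"):
        # Learnable stats would change via gradient descent, not here.
        use_batch, update = False, False
    else:  # 'none': always normalize with the current batch statistics
        use_batch, update = True, False

    if update:
        # Running estimates use the unbiased batch variance.
        corr = n / max(n - 1, 1)
        for j in range(f):
            stats["mean"][j] = ((1 - momentum) * stats["mean"][j]
                                + momentum * b_mean[j])
            stats["var"][j] = ((1 - momentum) * stats["var"][j]
                               + momentum * b_var[j] * corr)

    mean, var = (b_mean, b_var) if use_batch else (stats["mean"], stats["var"])
    return [[(s[j] - mean[j]) / math.sqrt(var[j] + eps) for j in range(f)]
            for s in x_batch]


running = {"mean": [0.0], "var": [1.0]}
print(bn_forward([[0.0], [2.0]], training=True, stats=running, eps=0.0))
# [[-1.0], [1.0]]
print(running)  # mean moved towards 1.0, var towards the unbiased value 2.0
```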
A list of lists of integers. Each list represents the shape of a
weight tensor that can be passed to the forward() method. If all
weights are maintained internally, then this attribute will be None.
Specifically, this attribute is controlled by the argument affine.
If affine is True, this attribute will be None. Otherwise,
this attribute contains the shapes of $\gamma$ and $\beta$.
A list of lists of integers. Each list represents the shape of a
parameter tensor.
Note, this attribute is independent of the attribute weights;
it always comprises the shapes of all weight tensors as if the network
were stand-alone (i.e., no weights being passed to the
forward() method).
Note, unless learnable_stats is enabled, the layer statistics are
not considered here.
This file has a collection of helper functions that can be used to specify
command-line arguments. In particular, arguments that are necessary for
multiple experiments (albeit with different default values) should be
specified here, such that we do not define arguments (and their help texts)
multiple times.
All functions specified here are helper functions for a simulation-specific
argument parser such as cifar.train_args.parse_cmd_arguments().
NEVER CHANGE DEFAULT VALUES. Instead, add a keyword argument to the
corresponding method that allows you to change the default value when you
call the method.
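The convention above can be illustrated with a hypothetical, heavily reduced helper (train_args_sketch is not the real function and only mimics options such as lr and epochs; its default values are made up). Each experiment changes a default by passing a keyword argument (dlr, depochs) or hides an option via a show_* flag, while the helper itself stays untouched.

```python
import argparse


def train_args_sketch(parser, dlr=0.001, show_epochs=True, depochs=10):
    """Hypothetical helper: defaults are configurable via keyword arguments."""
    agroup = parser.add_argument_group("Training options")
    agroup.add_argument("--lr", type=float, default=dlr,
                        help="Learning rate. Default: %(default)s.")
    if show_epochs:
        agroup.add_argument("--epochs", type=int, default=depochs,
                            help="Number of training epochs. "
                                 "Default: %(default)s.")
    return agroup


# Experiment A uses the helper's defaults.
parser_a = argparse.ArgumentParser()
train_args_sketch(parser_a)
# Experiment B needs another default learning rate and no epochs option.
parser_b = argparse.ArgumentParser()
train_args_sketch(parser_b, dlr=0.01, show_epochs=False)

args_a = parser_a.parse_args([])
args_b = parser_b.parse_args([])
print(args_a.lr, args_a.epochs)               # 0.001 10
print(args_b.lr, hasattr(args_b, "epochs"))   # 0.01 False
```

Because the helper returns the argument group, an experiment can still append its own options to the same group.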
This is a helper method of the method parse_cmd_arguments (or, more
specifically, an auxiliary method to train_args()) to add arguments to
an argument group for options specific to a main network that should act as
a generator.
Arguments specified in this function:
latent_dim
latent_std
Parameters:
agroup – The argument group returned by, for instance, function
main_net_args().
dhmlp_arch (str) – Default value of option hmlp_arch.
show_cond_emb_size (bool) – Whether the option cond_emb_size should be
provided.
dcond_emb_size (int) – Default value of option cond_emb_size.
dchmlp_chunk_size (int) – Default value of option chmlp_chunk_size.
dchunk_emb_size (int) – Default value of option chunk_emb_size.
show_use_cond_chunk_embs (bool) – Whether the option
use_cond_chunk_embs should be provided (if applicable to
network types).
dhdeconv_shape (str) – Default value of option hdeconv_shape.
prefix (str, optional) – If arguments should be instantiated with a
certain prefix. E.g., a setup requires several hypernetworks that
may need different settings. For instance: prefix='gen_'.
pf_name (str, optional) – A name of the type of hypernetwork for which that
prefix is needed. For instance: pf_name='generator'.
**kwargs – Keyword arguments to configure options that are common across
main networks (note, a hypernet is just a special main network). See
arguments of main_net_args().
Returns:
The created argument group containing the
desired options.
List of allowed network identifiers. The following
identifiers are considered (note, we also reference the network that
each network type targets):
mlp: mnets.mlp.MLP
lenet: mnets.lenet.LeNet
resnet: mnets.resnet.ResNet
wrn: mnets.wide_resnet.WRN
iresnet: mnets.resnet_imgnet.ResNetIN
zenke: mnets.zenkenet.ZenkeNet
bio_conv_net: mnets.bio_conv_net.BioConvNet
chunked_mlp: mnets.chunk_squeezer.ChunkSqueezer
simple_rnn: mnets.simple_rnn.SimpleRNN
dmlp_arch – Default value of option mlp_arch.
dlenet_type – Default value of option lenet_type.
dcmlp_arch – Default value of option cmlp_arch.
dcmlp_chunk_arch – Default value of option cmlp_chunk_arch.
dcmlp_in_cdim – Default value of option cmlp_in_cdim.
dcmlp_out_cdim – Default value of option cmlp_out_cdim.
dcmlp_cemb_dim – Default value of option cmlp_cemb_dim.
dresnet_block_depth – Default value of option resnet_block_depth.
dresnet_channel_sizes – Default value of option resnet_channel_sizes.
dwrn_block_depth – Default value of option wrn_block_depth.
dwrn_widening_factor – Default value of option wrn_widening_factor.
diresnet_channel_sizes – Default value of option
iresnet_channel_sizes.
diresnet_blocks_per_group – Default value of option
iresnet_blocks_per_group.
dsrnn_rec_layers – Default value of option srnn_rec_layers.
dsrnn_pre_fc_layers – Default value of option srnn_pre_fc_layers.
dsrnn_post_fc_layers – Default value of option srnn_post_fc_layers.
dsrnn_rec_type – Default value of option srnn_rec_type.
show_net_act (bool) – Whether the option net_act should be provided.
dnet_act – Default value of option net_act.
show_no_bias (bool) – Whether the option no_bias should be provided.
show_dropout_rate (bool) – Whether the option dropout_rate should be
provided.
ddropout_rate – Default value of option dropout_rate.
show_specnorm (bool) – Whether the option specnorm should be provided.
show_batchnorm (bool) – Whether the option batchnorm should be
provided.
show_no_batchnorm (bool) – Whether the option no_batchnorm should be
provided.
show_bn_no_running_stats (bool) – Whether the option
bn_no_running_stats should be provided.
show_bn_distill_stats (bool) – Whether the option bn_distill_stats
should be provided.
show_bn_no_stats_checkpointing (bool) – Whether the option
bn_no_stats_checkpointing should be provided.
prefix (optional) – If arguments should be instantiated with a certain
prefix. E.g., a setup requires several main networks that may need
different settings. For instance: prefix='gen_'.
pf_name (optional) – A name of the type of main net for which that prefix
is needed. For instance: pf_name='generator'.
Returns:
The created argument group, in case more options should be added.
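A minimal sketch of how prefix and pf_name allow two networks to be configured from one command line, assuming the prefix is simply prepended to each option name. main_net_args_sketch and its two options are hypothetical and far smaller than main_net_args().

```python
import argparse


def main_net_args_sketch(parser, dmlp_arch="100,100", show_net_act=True,
                         dnet_act="relu", prefix=None, pf_name=None):
    """Hypothetical reduction of the prefix/pf_name mechanism."""
    p = "" if prefix is None else prefix
    title = "Main network options"
    if pf_name is not None:
        title = f"Main network options ({pf_name})"
    agroup = parser.add_argument_group(title)
    agroup.add_argument(f"--{p}mlp_arch", type=str, default=dmlp_arch,
                        help="Sizes of hidden layers. Default: %(default)s.")
    if show_net_act:
        agroup.add_argument(f"--{p}net_act", type=str, default=dnet_act,
                            help="Activation function. Default: %(default)s.")
    return agroup


parser = argparse.ArgumentParser()
main_net_args_sketch(parser)  # e.g., the classifier
main_net_args_sketch(parser, dmlp_arch="50", show_net_act=False,
                     prefix="gen_", pf_name="generator")
args = parser.parse_args(["--gen_mlp_arch", "20,20"])
print(args.mlp_arch, args.net_act, args.gen_mlp_arch)  # 100,100 relu 20,20
```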
big_data – If the program processes big datasets that need to be loaded
from disk on the fly. In this case, more options are provided.
synthetic_data – If data is randomly generated, then we want to decouple
this randomness from the training randomness.
show_plots – Whether the option show_plots should be provided.
no_cuda – If True, the user has to explicitly set the flag --use_cuda
rather than using CUDA by default.
dout_dir (optional) – Default value of option out_dir. If None,
the default value will be ./out/run_<YY>-<MM>-<DD>_<hh>-<mm>-<ss>
that contains the current date and time.
show_publication_style – Whether the option publication_style should be
provided.
Returns:
The created argument group, in case more options should be added.
show_lr – Whether the lr - learning rate - argument should be shown.
Might not be desired if individual learning rates per optimizer
should be specified.
dlr – Default value for option lr.
show_epochs – Whether the epochs argument should be shown.
depochs – Default value for option epochs.
dbatch_size – Default value for option batch_size.
dn_iter – Default value for option n_iter.
show_use_adam – Whether the use_adam argument should be shown. Will
also show the adam_beta1 argument.
dadam_beta1 – Default value for option adam_beta1.
show_use_rmsprop – Whether the use_rmsprop argument should be shown.
show_use_adadelta – Whether the use_adadelta argument should be shown.
show_use_adagrad – Whether the use_adagrad argument should be shown.
show_clip_grad_value – Whether the clip_grad_value argument should be
shown.
show_clip_grad_norm – Whether the clip_grad_norm argument should be
shown.
show_adam_beta1 – Whether the adam_beta1 argument should be
shown. Note, this argument is also shown when show_use_adam is
True.
show_momentum – Whether the momentum argument should be
shown.
Returns:
The created argument group, in case more options should be added.
Implementation of a layer that can apply context-dependent modulation on
the level of neuronal computation.
The layer consists of two parameter vectors: gains $\mathbf{g}$
and shifts $\mathbf{s}$, where gains represent a multiplicative
modulation of input activations and shifts an additive modulation.
Note, the weight vectors $\mathbf{g}$ and $\mathbf{s}$ might
also be passed to the forward() method, where one may pass a separate
set of parameters for each sample in the input batch.
Example
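Assume that a ContextModLayer is applied between a linear
(fully-connected) layer $\mathbf{y} \equiv W \mathbf{x} + \mathbf{b}$
with input $\mathbf{x}$
and a nonlinear activation function
$z \equiv \sigma(y)$.
The layer computation in such a case becomes

$$\sigma\big((W \mathbf{x} + \mathbf{b}) \odot \mathbf{g} + \mathbf{s}\big)$$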
Number of units in the layer (size of
parameter vectors $\mathbf{g}$ and $\mathbf{s}$).
In case a tuple of integers is provided, the gain $\mathbf{g}$
and shift $\mathbf{s}$ parameters will
become multidimensional tensors with the shape being prescribed
by num_features. Please note the broadcasting rules, as
$\mathbf{g}$ and $\mathbf{s}$ are simply multiplied with
or added to the input.
Example
Consider the output of a convolutional layer with output shape
[B,C,W,H]. In case there should be a scalar gain and shift
per feature map, num_features could be [C,1,1] or
[1,C,1,1] (one might also pass a shape [B,C,1,1] to the
forward() method to apply separate shifts and gains per
sample in the batch).
Alternatively, one might want to provide shift and gain per
output unit, i.e., num_features should be [C,W,H]. Note,
that due to weight sharing, all output activities within a
feature map are computed using the same weights, which is why it
is common practice to share shifts and gains within a feature
map (e.g., in Spatial Batch-Normalization).
no_weights (bool) – If True, the layer will have no trainable weights
($\mathbf{g}$ and $\mathbf{s}$). Hence, weights are
expected to be passed to the forward() method.
If activated, this option will apply
a constant offset of 1 to all gains, i.e., the computation becomes

$$\sigma\big((W \mathbf{x} + \mathbf{b}) \odot (1 + \mathbf{g}) + \mathbf{s}\big)$$

When could that be useful? In case the gains and shifts are
generated by the same hypernetwork, a meaningful initialization
might be difficult to achieve (e.g., such that gains are close to 1
and shifts are close to 0 at the beginning). Therefore, one might
initialize the hypernetwork such that all outputs are close to zero
at the beginning, and the constant offset ensures that meaningful
gains are applied.
If activated, this option will
enforce positive gain modulation by sending the gain weights
through a softplus function (scaled as determined by
softplus_scale).
softplus_scale (float) – If option apply_gain_softplus is True,
then this will determine the scale of the softplus function.
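The following pure-Python sketch illustrates the computation $\mathbf{h} \odot \mathbf{g} + \mathbf{s}$ on pre-activations $\mathbf{h} = W\mathbf{x} + \mathbf{b}$, including per-sample parameters and optional gain preprocessing. The names context_mod and gain_transform are hypothetical. How softplus_scale enters the softplus is an assumption here (mirroring the beta argument of torch.nn.functional.softplus(), i.e., $\log(1 + e^{\beta z})/\beta$); check the class code before relying on it.

```python
import math


def softplus(z, scale=1.0):
    # Numerically stable log(1 + exp(scale * z)) / scale (assumed scaling).
    bz = scale * z
    return (max(bz, 0.0) + math.log1p(math.exp(-abs(bz)))) / scale


def context_mod(h, gains, shifts, gain_transform=None):
    """Apply h * g + s elementwise to a batch h of shape [N][F].

    gains/shifts are either flat lists of length F (shared by the batch)
    or nested lists of shape [N][F] (one set of parameters per sample).
    gain_transform optionally preprocesses raw gains (offset, softplus).
    """
    out = []
    for n, sample in enumerate(h):
        g = gains[n] if isinstance(gains[0], list) else gains
        s = shifts[n] if isinstance(shifts[0], list) else shifts
        if gain_transform is not None:
            g = [gain_transform(gi) for gi in g]
        out.append([hi * gi + si for hi, gi, si in zip(sample, g, s)])
    return out


h = [[1.0, 2.0]]  # pre-activations W x + b of a single sample
print(context_mod(h, [0.0, 0.0], [0.5, 0.5]))  # [[0.5, 0.5]]
# With a constant gain offset of 1, near-zero raw gains act like identity.
print(context_mod(h, [0.0, 0.0], [0.5, 0.5],
                  gain_transform=lambda g: 1.0 + g))  # [[1.5, 2.5]]
```

The nonlinearity $\sigma$ would be applied to the returned values afterwards.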
Buffers for a new checkpoint will be registered and the current weights
will be copied into them. Additionally, the current weights will be
reinitialized (gains to 1 and shifts to 0).
Calling this function will also increment the attribute
num_ckpts.
Note
This method uses the method torch.nn.Module.register_buffer()
rather than the method torch.nn.Module.register_parameter() to
create checkpoints. The reason is that we don’t want the
checkpoints to appear as trainable weights (when calling
torch.nn.Module.parameters()). However, that means that
training on checkpointed weights cannot be continued unless they are
copied back into an actual torch.nn.Parameter object.
Parameters:
device (optional) – If not provided, the newly created checkpoint
will be moved to the device of the current weights.
no_reinit (bool) – If True, the actual weights will not
be reinitialized.
Weights that should be used instead of the internally
maintained ones (determined by attribute weights). Note,
if no_weights was True in the constructor, then this
parameter is mandatory.
Usually, the shape of the passed weights should follow the
attribute param_shapes, which is a tuple of shapes
[[num_features],[num_features]] (at least for linear
layers, see docstring of argument num_features in the
constructor for more details). However, one may also
specify a separate set of context-mod parameters per input
sample. Assume x has shape [num_samples,num_features].
Then weights may have the shape
[[num_samples,num_features],[num_samples,num_features]].
Get the current (or a set of checkpointed) weights of this context-mod
layer.
Parameters:
ckpt_id (optional) – ID of checkpoint. If not provided, the current
set of weights is returned.
If ckpt_id==self.num_ckpts, then this method also
returns the current weights, as the checkpoint has not been
created yet.
A list of lists of integers. Each list represents the shape of a
parameter tensor. Note, this attribute is independent of the attribute
weights; it always comprises the shapes of all weight tensors as
if the network were stand-alone (i.e., no weights being passed to
the forward() method).
Note
The weights passed to the forward() method might deviate
from these shapes, as we allow passing a distinct set of
parameters per sample in the input batch.
Depending on the user configuration, gains might be preprocessed before
being applied for context modulation (e.g., see attributes
gain_offset_applied or gain_softplus_applied). This
method transforms raw gains such that they can be applied to the network
activation.
Note
This method is called by forward() to transform given
gains.
Compute estimates of the diagonal elements of the Fisher information
matrix, as needed as importance weights by elastic weight consolidation
(EWC).
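The Fisher matrix for a conditional distribution $p(y \mid \theta, x)$
(i.e., the model likelihood for a model with parameters $\theta$) is
defined as follows at location $x$

$$\mathcal{F}(x) = \mathbb{E}_{p(y \mid \theta, x)} \big[ \nabla_\theta \log p(y \mid \theta, x) \, \nabla_\theta \log p(y \mid \theta, x)^T \big]$$

In practice, we are often interested in the Fisher averaged over locations

$$\mathcal{F} = \mathbb{E}_{p(x)} \big[ \mathcal{F}(x) \big]$$

Since the model is trained such that, in-distribution, the model likelihood
$p(y \mid \theta, x)$ and the ground-truth likelihood $p(y \mid x)$
agree, people often refer to the empirical Fisher, which
utilizes the dataset for computation and therewith doesn’t require sampling
from the model likelihood. Note, EWC anyway assumes that in-distribution
$p(y \mid \theta, x) \approx p(y \mid x)$ in order to be able to replace
the Hessian by the Fisher matrix.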
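The difference between the Fisher and the empirical Fisher can be seen on a toy model with closed-form gradients. The sketch below assumes a binary logistic-regression model $p(y=1 \mid \theta, x) = \mathrm{sigmoid}(\theta^T x)$, chosen purely for illustration; compute_fisher() itself obtains gradients of PyTorch networks via backward(). Since $y$ is binary, the expectation over the model likelihood can be computed exactly instead of by sampling. Function names are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def grad_log_lik(w, x, y):
    # d/dw log p(y | w, x) for p(y=1 | w, x) = sigmoid(w^T x).
    p1 = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    return [(y - p1) * xi for xi in x]


def diag_fisher(w, xs, ys=None):
    """Diagonal Fisher estimate averaged over the inputs xs.

    ys given -> empirical Fisher (labels taken from the dataset).
    ys None  -> expectation over y ~ p(y | w, x), computed exactly by
                summing over both classes.
    """
    fisher = [0.0] * len(w)
    for i, x in enumerate(xs):
        if ys is not None:
            outcomes = [(ys[i], 1.0)]
        else:
            p1 = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
            outcomes = [(1, p1), (0, 1.0 - p1)]
        for y, prob in outcomes:
            for j, gj in enumerate(grad_log_lik(w, x, y)):
                fisher[j] += prob * gj ** 2
    return [f / len(xs) for f in fisher]


w = [1.0, 0.0]
xs = [[1.0, 1.0], [1.0, -1.0]]
print(diag_fisher(w, xs))             # model-based Fisher
print(diag_fisher(w, xs, ys=[1, 0]))  # empirical Fisher
```

For this model, the model-based diagonal entries equal $p(1-p)x_j^2$ averaged over inputs, whereas the empirical version depends on the observed labels.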
Note
This method registers buffers in the given module (storing the
current parameters and the estimate of the Fisher diagonal elements),
i.e., the mnet if hnet is None, otherwise the hnet.
Parameters:
task_id – The ID of the current task, needed to store the computed
tensors with a unique name. When hnet is given, it is used as
input to the hnet forward method to select the current task
embedding.
data – A data handler. We will compute the Fisher estimate across the
whole training set (unless n_max is specified).
params – A list of parameter tensors from the module of which we aim to
compute the Fisher for. If hnet is given, then these are assumed
to be the “theta” parameters, that we pass to the forward function
of the hypernetwork. Otherwise, these are the “weights” passed to
the forward method of the main network.
Note, they must not be detached from their original parameters,
because we use backward() on the computational graph to read out
the .grad variable.
Note, the order in which these parameters are passed to this method
and the corresponding EWC loss function must not change, because
the index within the “params” list will be used as unique
identifier.
device – Current PyTorch device.
mnet – The main network. If hnet is None, then params are
assumed to belong to this network. The Fisher estimate will be
computed accordingly.
Note, params might be the output of a task-conditioned
hypernetwork, i.e., weights for a specific task. In this case,
“online”-EWC doesn’t make much sense, as we don’t follow the
Bayesian view of using the old task weights as prior for the current
ones. Instead, we have a separate set of weights for each task.
hnet (optional) – If given, params is assumed to correspond to the
unconditional weights (which does not include, for
instance, task embeddings) of the hypernetwork. In this case, the
diagonal Fisher entries belong to weights of the hypernetwork. The
Fisher will then be computed based on the probability
$p(y \mid \theta, x, \text{task\_id})$, where task_id is just a
constant input (representing the corresponding conditional weights,
e.g., task embedding) in addition to the training samples $x$.
empirical_fisher – If True, we compute the Fisher based on training
targets.
online – If True, then we use online EWC; hence, there is only one
diagonal Fisher approximation and one target parameter value stored
at a time, rather than for all previous tasks.
gamma – The gamma parameter for online EWC, controlling the gradual decay
of previous tasks.
n_max (optional) – If not -1, this will be the maximum number of
samples considered for estimating the Fisher.
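regression – Whether the task at hand is a classification or regression
task. If True, a regression task is assumed. For simplicity, we
assume the following probabilistic model

$$p(y \mid \theta, x) = \mathcal{N}\big(y; f_\theta(x), I\big)$$

with $I$ being the identity matrix. In this case, the only term of the log
probability that influences the gradient is the MSE:

$$\log p(y \mid \theta, x) = -\frac{1}{2} \big\lVert y - f_\theta(x) \big\rVert^2 + \text{const}$$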
If True, the output of the main network
mnet is expected to be a time series. In particular, we
assume that the output is a tensor of shape [S,N,F],
where S is the length of the time series, N is the batch
size and F is the size of each feature vector (e.g., in
classification, F would be the number of classes).
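Let $\mathbf{y} = (\mathbf{y}_1, \dots, \mathbf{y}_S)$ be the
output of the main network. We denote the parameters params by
$\theta$ and the input by $\mathbf{x}$ (which we do not
consider as random). We use the following decomposition of the
likelihood

$$p(\mathbf{y} \mid \theta, \mathbf{x}) = \prod_{t=1}^{S} p(\mathbf{y}_t \mid \mathbf{y}_{t-1}, \dots, \mathbf{y}_1, \theta, \mathbf{x})$$

Classification: If $f_\theta(\mathbf{x})_t$ denotes the output
of the main network mnet for timestep $t$ (assuming $\mathbf{h}_t$
is the most recent hidden state), we assume

$$p(\mathbf{y}_t \mid \mathbf{y}_{t-1}, \dots, \mathbf{y}_1, \theta, \mathbf{x}) \equiv \mathrm{softmax}\big(f_\theta(\mathbf{x})_t\big)$$

Hence, we assume that we can write the negative log-likelihood (NLL)
as follows given a label $l \in \{1, \dots, F\}^S$:

$$\mathrm{NLL} = -\log p(\mathbf{y} = l \mid \theta, \mathbf{x}) = \sum_{t=1}^{S} -\log \mathrm{softmax}\big(f_\theta(\mathbf{x})_t\big)_{l_t} = \sum_{t=1}^{S} \mathrm{cross\_entropy}\big(f_\theta(\mathbf{x})_t, l_t\big)$$

Thus, we simply sum the cross-entropy losses per time step to
estimate the NLL, which we then backpropagate through in order to
compute the diagonal Fisher elements.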
allowed_outputs (optional) – A list of indices, indicating which output
neurons of the main network should be taken into account when
computing the log probability. If not specified, all output neurons
are considered.
custom_forward (optional) –
A function handle that can replace the
default procedure of forwarding samples through the given
network(s).
The signature of this function should be as follows.
hnet is None: @fun(mnet,params,X)
hnet is not None:
@fun(mnet,hnet,task_id,params,X)
where X denotes the input batch to the main network (usually
consisting of a single sample).
Example
Imagine a situation where the main network uses context-dependent
modulation (cf.
utils.context_mod_layer.ContextModLayer) and the
parameters of these context-mod layers are produced by the
hypernetwork hnet, whereas the remaining weights of the
main network mnet are maintained internally and passed as
argument params to this method.
In particular, we look at a main network that is an instance
of class mnets.mlp.MLP. The forward pass through this
combination of networks should be handled as follows in order
to compute the correct Fisher matrix:
A function handle that can replace the default
procedure of computing the negative-log-likelihood (NLL), which is
required to compute the Fisher.
The signature of this function should be as follows:
@fun(Y,T,data,allowed_outputs,empirical_fisher)
where Y are the outputs of the main network. Note,
allowed_outputs have already been applied to Y, if given.
T is the target provided by the dataset data, transformed as
follows:
where batch_ids are the unique identifiers as returned by
option return_ids of method
data.dataset.Dataset.next_train_batch() corresponding to the
provided samples.
Example
In sequential datasets, target sequences T might be padded
to the same length. However, if the unpadded length should be
used for NLL computation, then the custom_nll function needs
the ability to request this information (sequence length) from
data.
Also, the signatures of custom_forward are expected to be
different.
The signature of this function should be as follows.
hnet is None: @fun(mnet,params,X,data,batch_ids)
hnet is not None:
@fun(mnet,hnet,task_id,params,X,data,batch_ids)
The algorithm Online EWC is based on a Taylor
approximation of the posterior that leads to the following
estimate
Due to the presentation of the algorithm in the paper and inspired
by multiple publicly available implementations, we approximate the
regularization strength in practice via
where is a hyperparameter.
If this argument is True, then the sum of Fisher matrices is
properly weighted by the dataset size (independent of argument
n_max).
prior_strength (float or list, optional) – Either a scalar or a list of
Tensors with the same shapes as params. Only applies to
Online EWC. One can specify an offset for all Fisher values, e.g.,
. See argument proper_scaling
for details.
regression_lvar (float) – In regression, this refers to the variance of
the likelihood.
target_manipulator (func, optional) –
A function with signature
T=target_manipulator(T)
that may manipulate the targets coming from the dataset.
See argument custom_forward of function compute_fisher() for more
details.
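For the time-series classification case of compute_fisher() described above, the NLL is a sum of per-timestep cross-entropies. A minimal pure-Python sketch for a single sequence (batch size N = 1), with hypothetical function names; the Fisher computation would backpropagate through this quantity.

```python
import math


def log_softmax(logits):
    # Subtract the maximum for numerical stability.
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def sequence_nll(outputs, labels):
    """NLL of one sequence: sum over timesteps of the cross-entropy.

    outputs: S lists of F logits f_theta(x)_t; labels: S class indices l_t.
    """
    return sum(-log_softmax(f_t)[l_t] for f_t, l_t in zip(outputs, labels))


outputs = [[0.0, 0.0], [2.0, 0.0], [0.0, 0.0]]
labels = [0, 0, 1]
print(sequence_nll(outputs, labels))  # 2*log(2) + log(1 + exp(-2))
```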
This is a helper method to quickly retrieve a function handle that manages
the forward pass for a context-modulated main network.
We assume that the interface of the main network is similar to the one of
mnets.mlp.MLP.forward().
Parameters:
mod_weights (optional) – If provided, it is assumed that
compute_fisher() is called with hnet set to None.
Hence, the returned function handle will have the given
context-modulation pattern hard-coded.
If left unspecified, it is assumed that a hnet is passed to
compute_fisher() and that this hnet computes only the
parameters of all context-mod layers.
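A hypothetical sketch of such a factory, returning handles with the two custom_forward signatures listed for compute_fisher(). It assumes, based on the reference to mnets.mlp.MLP.forward() above, that the main network accepts weights as a dict with keys 'internal_weights' and 'mod_weights', and that hnet.forward() takes the task via a cond_id keyword; both are assumptions to verify against the actual class documentation. The stub classes exist only to make the example runnable.

```python
def make_cm_forward(mod_weights=None):
    """Return a custom_forward handle for a context-modulated main network."""
    if mod_weights is not None:
        # Use with hnet=None: the context-mod weights are hard-coded.
        def custom_forward(mnet, params, X):
            return mnet.forward(X, weights={"internal_weights": params,
                                            "mod_weights": mod_weights})
    else:
        # Use with a hypernetwork that outputs only context-mod weights.
        def custom_forward(mnet, hnet, task_id, params, X):
            mod = hnet.forward(cond_id=task_id)  # assumed hnet interface
            return mnet.forward(X, weights={"internal_weights": params,
                                            "mod_weights": mod})
    return custom_forward


# Minimal stand-ins for illustration only (not the real network classes).
class EchoMnet:
    def forward(self, X, weights=None):
        return X, weights


class ConstHnet:
    def forward(self, cond_id=None):
        return [f"cm_weights_task{cond_id}"]


fwd = make_cm_forward()
print(fwd(EchoMnet(), ConstHnet(), 3, ["W"], "x"))
# ('x', {'internal_weights': ['W'], 'mod_weights': ['cm_weights_task3']})
```

Because params only contains the internal main-network weights, the Fisher is then computed for those weights (or for the hypernetwork weights, if hnet is passed).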
Compute the EWC regularizer that can be added to the remaining loss.
Note, the hyperparameter that trades off the regularization strength is
not yet multiplied with the loss.
This loss assumes an appropriate use of the method “compute_fisher”. Note,
for the current task “compute_fisher” has to be called after calling this
method.
If online is False, this method implements the loss proposed in eq. (3) in
[EWC2017], except for the missing hyperparameter lambda.
The online EWC implementation follows eq. (8) from [OnEWC2018] (note, that
lambda does not appear in this equation, but it was used in their
experiments).
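A pure-Python sketch of the penalty in eq. (3) of [EWC2017] without lambda, i.e., $\sum_i \frac{1}{2} F_i (\theta_i - \theta^*_i)^2$ summed over previous tasks, together with the recursion $F \leftarrow \gamma F_{\text{old}} + F_{\text{new}}$ used by online EWC to keep a single Fisher estimate. Function names are hypothetical, and whether this module keeps the factor 1/2 or folds it into the hyperparameter is an assumption.

```python
def ewc_regularizer_sketch(theta, anchors):
    """sum_k sum_i 0.5 * F_{k,i} * (theta_i - theta*_{k,i})^2, no lambda.

    anchors: list of (theta_star, fisher) pairs, one per previous task for
    standard EWC, or a single pair for online EWC.
    """
    return sum(0.5 * f * (t - ts) ** 2
               for theta_star, fisher in anchors
               for t, ts, f in zip(theta, theta_star, fisher))


def online_fisher_update(fisher_old, fisher_new, gamma):
    # Online EWC keeps one Fisher estimate: F <- gamma * F_old + F_new.
    if fisher_old is None:
        return list(fisher_new)
    return [gamma * fo + fn for fo, fn in zip(fisher_old, fisher_new)]


F = online_fisher_update(None, [1.0, 1.0], gamma=0.5)
F = online_fisher_update(F, [2.0, 0.0], gamma=0.5)
print(F)  # [2.5, 0.5]
print(ewc_regularizer_sketch([1.0, 2.0], [([0.0, 0.0], F)]))  # 2.25
```

In a training loop, the returned value would be multiplied by the hyperparameter lambda and added to the task loss.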
It is computed based on the assumption that values greater than a threshold
are classified as real.
Note, the accuracy measure is only well defined for the Vanilla GAN.
However, we simply look at generally preferred value ranges and generalize
the concept of accuracy to the other GAN formulations using the
following thresholds:
GANs often run into mode collapse since the discriminator sees every
sample in isolation, i.e., it cannot detect whether all samples in a batch
look alike.
A simple way to give the discriminator access to batch statistics
is to concatenate the mean (across the batch dimension) of all
discriminator samples to each sample.
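A minimal sketch of this minibatch-mean feature on plain lists of shape [N][F] (the function name is hypothetical); each output sample has 2F features.

```python
def concat_batch_mean(batch):
    """Append the per-feature batch mean to every sample of shape [N][F]."""
    n = len(batch)
    mean = [sum(col) / n for col in zip(*batch)]
    # New lists are created, so the input batch is left unchanged.
    return [sample + mean for sample in batch]


print(concat_batch_mean([[1.0, 2.0], [3.0, 4.0]]))
# [[1.0, 2.0, 2.0, 3.0], [3.0, 4.0, 2.0, 3.0]]
```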
Define what loss function is used to train the GAN.
Note, the choice of loss function also influences how the output
of the discriminator network is reinterpreted or squashed (either
into [0,1] or left as an arbitrary real number).
The following choices are available.
0: Vanilla GAN (Goodfellow et al., 2014). Non-saturating
loss version. Note, we additionally apply one-sided label
smoothing for this loss.
1: Traditional LSGAN (Mao et al., 2018). See eq. 14 of
the paper. This loss corresponds to a parameter
choice $a = 0$, $b = 1$ and $c = 1$.
2: Pearson Chi^2 LSGAN (Mao et al., 2018). See eq. 13.
Parameter choice: $a = -1$, $b = 1$ and $c = 0$.
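In the LSGAN formulation (Mao et al.), the discriminator minimizes $\frac{1}{2}\mathbb{E}_x[(D(x) - b)^2] + \frac{1}{2}\mathbb{E}_z[(D(G(z)) - a)^2]$ and the generator minimizes $\frac{1}{2}\mathbb{E}_z[(D(G(z)) - c)^2]$, so $a$ and $b$ are the target codes for fake and real samples. The sketch below computes these losses as well as the non-saturating vanilla losses on raw discriminator outputs (logits). The smoothed real label 0.9 is an assumed, commonly used value, not necessarily the one used by this module; function names are hypothetical.

```python
import math


def _softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def mean(values):
    return sum(values) / len(values)


def vanilla_losses(d_real_logits, d_fake_logits, real_label=0.9):
    """Non-saturating GAN losses with one-sided label smoothing.

    Uses BCE(z, t) = t * softplus(-z) + (1 - t) * softplus(z); only the
    real targets are smoothed, fake targets stay 0.
    """
    d_loss = (mean([real_label * _softplus(-z)
                    + (1 - real_label) * _softplus(z) for z in d_real_logits])
              + mean([_softplus(z) for z in d_fake_logits]))
    g_loss = mean([_softplus(-z) for z in d_fake_logits])  # -log sigmoid(z)
    return d_loss, g_loss


def lsgan_losses(d_real, d_fake, a, b, c):
    """LSGAN losses with target codes a (fake), b (real) and c."""
    d_loss = (0.5 * mean([(z - b) ** 2 for z in d_real])
              + 0.5 * mean([(z - a) ** 2 for z in d_fake]))
    g_loss = 0.5 * mean([(z - c) ** 2 for z in d_fake])
    return d_loss, g_loss


print(lsgan_losses([1.0], [0.0], a=0, b=1, c=1))    # (0.0, 0.5)
print(lsgan_losses([1.0], [-1.0], a=-1, b=1, c=0))  # (0.0, 0.5)
```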
The pseudocode of the algorithm is described in Figure 2 of the paper. The
algorithm uses the Leapfrog algorithm to simulate the Hamiltonian dynamics in
discrete time. Therefore, two crucial hyperparameters are required: the stepsize
$\epsilon$ and the number of steps $L$. Both hyperparameters have to
be chosen with care and can drastically influence the behavior of HMC. If the
stepsize $\epsilon$ is too small, we don’t explore the state space
efficiently and waste computation. If it is too big, the numerical error from
the discretization might become too large and the acceptance rate rather low. In
addition, we want to choose $L$ large enough to obtain good exploration,
but if we set it too large, we might loop back to the starting position.
The No-U-Turn-Sampler (NUTS) has been proposed to set $L$ automatically,
such that only the stepsize $\epsilon$ has to be chosen.
Hoffman et al., “The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo”, 2011.
This module provides implementations for both variants, basic HMC and
NUTS. Multiple parallel chains can be simulated via class
MultiChainHMC. For Bayesian Neural Networks, the helper function
nn_pot_energy() can be used to define the potential energy.
Notation
We largely follow the notation from
Neal et al. The variable of interest,
e.g., model parameters, is encoded by the position vector $q$. In
addition, HMC requires a momentum $p$. The Hamiltonian $H(q, p) = U(q) + K(p)$
consists of two terms, the potential energy $U(q)$ and the kinetic energy
$K(p) = \frac{1}{2} p^T M^{-1} p$ with $M$ being a symmetric, p.d. “mass”
matrix.
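The Hamiltonian dynamics can thus be summarized as

$$\frac{dq_i}{dt} = \frac{\partial H}{\partial p_i} = [M^{-1} p]_i, \qquad \frac{dp_i}{dt} = -\frac{\partial H}{\partial q_i} = -\frac{\partial U}{\partial q_i}$$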
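The Leapfrog algorithm is a way to discretize the differential equation above
in a way that is reversible and volume preserving. The algorithm has two
hyperparameters: the stepsize $\epsilon$ and the number of steps
$L$. Below, we sketch the algorithm to update momentum and position from
time $t$ to time $t + \epsilon$.

$$p_i(t + \epsilon/2) = p_i(t) - \frac{\epsilon}{2} \frac{\partial U}{\partial q_i}\big(q(t)\big)$$

$$q_i(t + \epsilon) = q_i(t) + \epsilon \frac{p_i(t + \epsilon/2)}{m_i}$$

$$p_i(t + \epsilon) = p_i(t + \epsilon/2) - \frac{\epsilon}{2} \frac{\partial U}{\partial q_i}\big(q(t + \epsilon)\big)$$

We assume a diagonal mass matrix $M = \mathrm{diag}(m_1, m_2, \dots)$ in the position update above.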
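The three update equations map directly onto code. Below is a minimal pure-Python sketch with a diagonal (or scalar) inverse mass matrix; the function signature is hypothetical and need not match leapfrog() in this module. The usage checks two properties mentioned above on a standard normal target: approximate energy conservation and reversibility.

```python
def leapfrog(q, p, grad_U, stepsize, num_steps, inv_mass=1.0):
    """Simulate Hamiltonian dynamics with num_steps leapfrog steps.

    inv_mass is a scalar or a list holding the diagonal of M^{-1}.
    """
    inv_m = inv_mass if isinstance(inv_mass, list) else [inv_mass] * len(q)
    q, p = list(q), list(p)
    g = grad_U(q)
    for _ in range(num_steps):
        # Half step for the momentum.
        p = [pi - 0.5 * stepsize * gi for pi, gi in zip(p, g)]
        # Full step for the position (diagonal mass matrix).
        q = [qi + stepsize * mi * pi for qi, mi, pi in zip(q, inv_m, p)]
        g = grad_U(q)
        # Second half step for the momentum.
        p = [pi - 0.5 * stepsize * gi for pi, gi in zip(p, g)]
    return q, p


# Standard normal target: U(q) = q^2 / 2, hence dU/dq = q.
grad_U = lambda q: list(q)
H = lambda q, p: 0.5 * q[0] ** 2 + 0.5 * p[0] ** 2

q1, p1 = leapfrog([1.0], [0.5], grad_U, stepsize=0.1, num_steps=20)
print(H([1.0], [0.5]), H(q1, p1))  # nearly equal energies
# Reversibility: negate the momentum and integrate back to the start.
q0, p0 = leapfrog(q1, [-p1[0]], grad_U, stepsize=0.1, num_steps=20)
print(q0, p0)  # approximately [1.0] [-0.5]
```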
The position variable should be provided as a vector. The weights
of a neural network can be flattened via
mnets.mnet_interface.MainNetInterface.flatten_params().
pot_energy_func (func) –
A function handle computing the potential
energy $U(q)$ upon receiving a position $q$. To sample
the weights of a neural network, the helper function
nn_pot_energy() can be used. To sample via HMC from a target
distribution implemented via
torch.distributions.distribution.Distribution, one can
define a function handle that returns the negative log-probability
of the given position, e.g., computed via the distribution’s log_prob() method.
stepsize (float) – The stepsize of the leapfrog()
algorithm.
num_steps (int) – The number of steps in the leapfrog()
algorithm.
inv_mass (float or torch.Tensor) – The inverse “mass” matrix as required
for the computation of the kinetic energy $K(p)$. See argument
inv_mass of function leapfrog() for details.
logger (logging.Logger, optional) – If provided, the progress will be
logged.
log_interval (int) – After how many states the status should be logged.
writer (tensorboardX.SummaryWriter, optional) – A tensorboard writer.
If given, useful simulation data will be logged, like the
development of the Hamiltonian.
writer_tag (str) – Will be added to the tensorboard tags.
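To show how stepsize, num_steps and inv_mass interact, here is a hedged, self-contained sketch of a single HMC chain following Neal et al.: resample the momentum, simulate the dynamics with Leapfrog, then accept or reject based on the change in the Hamiltonian. The function name `hmc_chain` and its arguments are hypothetical and do not mirror the class HMC; `pot_energy` plays the role of pot_energy_func and returns the negative log-density up to a constant (e.g., $q^2/2$ for a standard normal target).

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def hmc_chain(q0: Sequence[float],
              pot_energy: Callable[[Vector], float],
              grad_u: Callable[[Vector], Vector],
              stepsize: float, num_steps: int,
              inv_mass: Sequence[float], num_states: int,
              seed: int = 0) -> Tuple[List[Vector], float]:
    """Hypothetical HMC sketch with a diagonal mass matrix M = diag(1 / inv_mass)."""
    rng = random.Random(seed)

    def kinetic(p: Vector) -> float:
        # K(p) = p^T M^{-1} p / 2 for a diagonal M.
        return 0.5 * sum(mi * pi * pi for mi, pi in zip(inv_mass, p))

    q = list(q0)
    states: List[Vector] = []
    num_accepted = 0
    for _ in range(num_states):
        # Draw a fresh momentum p ~ N(0, M), i.e., std sqrt(m_i) per coordinate.
        p = [rng.gauss(0.0, math.sqrt(1.0 / mi)) for mi in inv_mass]
        current_h = pot_energy(q) + kinetic(p)

        # Leapfrog integration (momentum half steps merged into full steps).
        q_new = list(q)
        p_new = [pi - 0.5 * stepsize * gi for pi, gi in zip(p, grad_u(q_new))]
        for step in range(num_steps):
            q_new = [qi + stepsize * mi * pi
                     for qi, mi, pi in zip(q_new, inv_mass, p_new)]
            scale = stepsize if step < num_steps - 1 else 0.5 * stepsize
            p_new = [pi - scale * gi for pi, gi in zip(p_new, grad_u(q_new))]

        # Metropolis correction: accept with probability min(1, exp(H - H')).
        proposed_h = pot_energy(q_new) + kinetic(p_new)
        if rng.random() < math.exp(min(0.0, current_h - proposed_h)):
            q = q_new
            num_accepted += 1
        # On rejection, the previous position is repeated.
        states.append(list(q))
    return states, num_accepted / num_states
```

With a small stepsize the energy error stays small and almost all proposals are accepted; increasing the stepsize lowers the acceptance rate, matching the trade-off described for $\epsilon$.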
A list containing all position variables (Markov states) visited so
far.
New positions will be added by the method simulate_chain(). To
decrease the memory footprint of objects in this class, the trajectory
can be cleared via method clear_position_trajectory().
Implementation of the Metropolis-Hastings algorithm.
This class implements the basic Metropolis-Hastings algorithm as, for
instance, outlined here (see alg. 1).
The Metropolis-Hastings algorithm is a simple MCMC algorithm. In contrast
to HMC, sampling is slow as positions follow a random walk.
However, the algorithm does not need access to gradient information, which
makes it applicable to a wider range of applications.
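We use a normal distribution $\mathcal{N}(q^{(t-1)}, \sigma^2 I)$ as proposal,
where $q^{(t-1)}$ denotes the previous position (sample point) and $\sigma$ the
standard deviation of the proposal. Thus, the proposal is symmetric and its
density cancels in the MH acceptance ratio.
The potential energy $U(q)$ is expected to be passed as negative log-probability
(up to a constant), such that

$$
\pi(q) \propto \exp\big(-U(q)\big)
$$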
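A minimal sketch of this random-walk Metropolis-Hastings scheme, assuming list-based positions; the function name and the argument `proposal_std` (the standard deviation of the Gaussian proposal) are hypothetical. Because the Gaussian proposal is symmetric, the acceptance probability reduces to $\min(1, \exp(U(q) - U(q')))$, and no gradient of the potential energy is needed.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def metropolis_hastings(q0: Sequence[float],
                        pot_energy: Callable[[Vector], float],
                        proposal_std: float, num_states: int,
                        seed: int = 0) -> Tuple[List[Vector], float]:
    """Hypothetical random-walk MH sketch; pot_energy returns -log p(q) + const."""
    rng = random.Random(seed)
    q = list(q0)
    u = pot_energy(q)
    states: List[Vector] = []
    num_accepted = 0
    for _ in range(num_states):
        # Symmetric Gaussian proposal centered at the previous position.
        q_prop = [qi + rng.gauss(0.0, proposal_std) for qi in q]
        u_prop = pot_energy(q_prop)
        # p(q') / p(q) = exp(U(q) - U(q')); the proposal densities cancel.
        if rng.random() < math.exp(min(0.0, u - u_prop)):
            q, u = q_prop, u_prop
            num_accepted += 1
        # On rejection, the previous position is repeated in the chain.
        states.append(list(q))
    return states, num_accepted / num_states
```

Compared to the HMC sketch, each step only moves a small random distance, which is why consecutive states are strongly correlated and longer chains are needed.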
A list containing all position variables (Markov states) visited so
far.
New positions will be added by the method simulate_chain(). To
decrease the memory footprint of objects in this class, the trajectory
can be cleared via method clear_position_trajectory().
Wrapper for running multiple HMC chains in parallel.
Samples obtained via an MCMC sampler are highly auto-correlated for two
reasons: (1) the proposal distribution is conditioned on the previous state
and (2) rejections lead to identical consecutive states. In
addition, it is unclear when the chain is long enough such that sufficient
exploration has taken place and the sample (excluding initial burn-in)
can be considered an i.i.d. sample from the target distribution. For this
reason, it is recommended to obtain an MCMC sample by running multiple
chains in parallel, starting from varying initial positions.
This class provides a simple wrapper to instantiate multiple chains from
HMC (and its subclasses) and provides an interface to easily
simulate those chains.
Parameters:
initial_positions (list or tuple) – A list of initial positions. The
length of this list will determine the number of chains to be
instantiated. Each element is an initial position as described for
argument initial_position of class HMC.
pot_energy_func (func) – See docstring of class HMC. One may
also provide a list of functions. For instance, if the potential
energy of a Bayesian neural network should be computed, there might
be a runtime speedup if each function uses a separate model instance.
The type of HMC algorithm to be used. The following options
are available:
'hmc': Each chain will be an instance of class HMC.
'nuts': Each chain will be an instance of class NUTS.
**kwargs –
Keyword arguments that will be passed to the constructor when
instantiating each chain. The following particularities should be
noted.
If a writer object is passed, then a chain-specific identifier
is added to the corresponding writer_tag, except if writer
is a string. In this case, we assume writer corresponds to an
output directory and we construct a separate object of class
tensorboardX.SummaryWriter per chain. In the latter case,
the scalars logged across chains are all shown within the same
tensorboard plot and are therefore easier to compare.
If a logger object is passed, then it will only be provided
to the first chain. If a logger should be passed to multiple
chain instances, then a list of objects from class
logging.Logger is required. If entries in this list are
None, then a simple console logger is generated for these
entries that displays the chain’s identity when logging a message.
Simulate the chains to gather a certain number of new positions.
This method simulates the internal chains to add num_states
positions to each considered chain.
Parameters:
num_states (int) – Each considered chain will be simulated for
this amount of HMC steps (see argument n of method
).
num_chains (int or list) – The number of chains to be considered. If
-1, then all chains will be simulated for num_states
steps. Otherwise, the num_chains chains with the lowest
number of states so far (according to attribute
HMC.num_states) are simulated. Alternatively, one may
specify a list of chain indices (numbers between 0 and
num_chains).
num_parallel (int) – How many chains should be simulated in parallel.
If 1, the chains are simulated consecutively (one after
another).
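The selection rule for num_chains can be made concrete with a small hypothetical helper. It only reproduces the documented semantics (all chains for -1, the chains with the fewest states for a positive integer, explicit indices for a list); breaking ties by chain index is an assumption.

```python
from typing import List, Sequence, Union


def select_chains(num_states_per_chain: Sequence[int],
                  num_chains: Union[int, Sequence[int]]) -> List[int]:
    """Return the indices of the chains that should be simulated next."""
    total = len(num_states_per_chain)
    if isinstance(num_chains, int):
        if num_chains == -1:
            return list(range(total))
        if not 0 < num_chains <= total:
            raise ValueError('num_chains must be -1 or in [1, number of chains].')
        # Stable sort: among equally long chains, lower indices come first.
        order = sorted(range(total), key=lambda i: num_states_per_chain[i])
        return sorted(order[:num_chains])
    indices = list(num_chains)
    if any(not 0 <= i < total for i in indices):
        raise ValueError('Chain indices must lie in [0, number of chains).')
    return indices
```

For example, with chain lengths `[5, 2, 7, 2]` and `num_chains=2`, the two shortest chains (indices 1 and 3) would be extended.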
In this class, we implement the efficient version of the NUTS algorithm
(see algorithm 3 in Hoffman et al.).
NUTS eliminates the need to choose the number of Leapfrog steps $L$.
While the algorithm is more computationally expensive than basic HMC, the
reduced hyperparameter effort has been shown to reduce the overall
computational cost (and it requires less human intervention).
As explained in the paper, a good heuristic for setting $L$ is to choose
the largest number of steps (for a given $\epsilon$) before the trajectory loops
back towards the initial position $q$, i.e., until the following quantity
becomes negative

$$
\frac{d}{dt} \frac{(\tilde{q} - q) \cdot (\tilde{q} - q)}{2} = (\tilde{q} - q) \cdot \tilde{p},
$$

where $\tilde{q}$ and $\tilde{p}$ denote the current position and momentum along the simulated trajectory.
Note, this equation assumes the mass matrix is the identity: $M = I$.
However, this approach is in general not time reversible; therefore, NUTS
proposes a recursive algorithm that extends the trajectory both forward and
backward in time. NUTS randomly adds
subtrees to a balanced binary tree and stops when any of those subtrees
starts making a “U-turn” (either forward or backward in time). This tree
construction is fully symmetric and therefore reversible.
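In the efficient NUTS variant of Hoffman et al., a (sub)trajectory with leftmost state $(q^-, p^-)$ and rightmost state $(q^+, p^+)$ is stopped as soon as $(q^+ - q^-) \cdot p^- < 0$ or $(q^+ - q^-) \cdot p^+ < 0$, i.e., when either end starts moving back towards the other. A minimal sketch of this check, assuming $M = I$ and a hypothetical function name:

```python
from typing import Sequence


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def is_u_turn(q_minus: Sequence[float], q_plus: Sequence[float],
              p_minus: Sequence[float], p_plus: Sequence[float]) -> bool:
    """True if the trajectory spanned by both end states doubles back (M = I)."""
    span = [qp - qm for qp, qm in zip(q_plus, q_minus)]
    # Checking both ends keeps the criterion symmetric in time.
    return _dot(span, p_minus) < 0 or _dot(span, p_plus) < 0
```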
Note
The NUTS paper also proposes to combine a heuristic approach for adapting
the stepsize $\epsilon$ with the automatic choice of $L$ (e.g., see
algorithm 6 in Hoffman et al.).
Such stepsize adaptation is currently not implemented by this class!
The leapfrog algorithm updates the position $q$ and momentum $p$
variables by simulating the Hamiltonian dynamics in discrete time for a
time window of size $L\epsilon$, where $L$ is the number of
leapfrog steps num_steps and $\epsilon$ is the stepsize.
In general, one can call this method $L$ times while setting
num_steps=1 in order to obtain the complete trajectory. However, if the
intermediate states are not needed, we recommend setting num_steps=L to save
the unnecessary computation of intermediate momentum variables.
The potential energy for Bayesian inference with HMC using neural
networks.
When obtaining samples from the posterior parameter distribution of a neural
network via HMC, a potential energy function has to be specified that allows
evaluating the negative log-posterior up to a constant. We consider a neural
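network with parameters $W$ which encodes a likelihood function
$p(y \mid W; x)$ for an input $x$. In addition, a prior $p(W)$
needs to be specified. Given a dataset $\mathcal{D}$
consisting of $N$ inputs $x_n$ and targets $y_n$, we can
specify the potential energy as (note, here $q = W$)

$$
U(W) = - \sum_{n=1}^{N} \log p(y_n \mid W; x_n) - \log p(W),
$$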
where the first term corresponds to the negative log-likelihood (NLL). The
precise way of computing the NLL depends on which kind of likelihood
interpretation is forced onto the network (cf. argument nll_type).
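As an illustration, consider a hypothetical regression setting: a linear model with weights $w$ stands in for the network, the likelihood is Gaussian with fixed variance $\sigma^2$ and the prior is an isotropic Gaussian with variance $\sigma_p^2$. Dropping constants, the potential energy becomes $U(w) = \sum_n \frac{(y_n - w^T x_n)^2}{2\sigma^2} + \frac{\lVert w \rVert^2}{2\sigma_p^2}$. The function below is a sketch under these assumptions, not nn_pot_energy() itself.

```python
from typing import Sequence


def regression_pot_energy(w: Sequence[float],
                          inputs: Sequence[Sequence[float]],
                          targets: Sequence[float],
                          likelihood_var: float = 1.0,
                          prior_var: float = 1.0) -> float:
    """Negative log-posterior (up to a constant) of a Bayesian linear model."""
    nll = 0.0
    for x, y in zip(inputs, targets):
        pred = sum(wi * xi for wi, xi in zip(w, x))
        # Gaussian NLL without the constant log-normalizer.
        nll += (y - pred) ** 2 / (2.0 * likelihood_var)
    # Negative log of an isotropic Gaussian prior, again without constants.
    neg_log_prior = sum(wi * wi for wi in w) / (2.0 * prior_var)
    return nll + neg_log_prior
```

A function like this, combined with its gradient, could serve as the potential energy of an HMC chain over $w$.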
A list of outputs of the hypernetwork. Each list entry
must have the output shape as returned by the
hnets.hnet_interface.HyperNetInterface.forward() method of the
hnet. Note, this function doesn’t detach targets. If desired,
that should be done before calling this function.
The current direction of weight change for the
internal (unconditional) weights of the hypernetwork evaluated on
the task-specific loss, i.e., the weight change that would be
applied to the unconditional parameters $\Theta$. This
regularizer aims to modify this direction, such that the hypernet
output for embeddings of previous tasks remains unaffected.
Note, this function does not detach dTheta. It is up to the
user to decide whether dTheta should be a constant vector or
might depend on parameters of the hypernet.
Also see utils.optim_step.calc_delta_theta().
dTembs (list, optional) – The current direction of weight change for the
task embeddings of all tasks that have been learned already.
See dTheta for details.
mnet – Instance of the main network. Has to be provided if
inds_of_out_heads are specified.
inds_of_out_heads (list, optional) – List of lists of integers, denoting
which output neurons of the main network are used for predictions of
the corresponding previous tasks.
This will ensure that only weights of output neurons involved in
solving a task are regularized.
If provided, the method
mnets.mnet_interface.MainNetInterface.get_output_weight_mask() of the
main network mnet is used to determine which hypernetwork
outputs require regularization.
fisher_estimates (list, optional) – A list of list of tensors, containing
estimates of the Fisher Information matrix for each weight
tensor in the main network and each task.
Note that len(fisher_estimates)==task_id.
The Fisher estimates are used as importance weights for single
weights when computing the regularizer.
prev_theta (list, optional) – If given, prev_task_embs but not
targets has to be specified. prev_theta is expected to be
the internal unconditional weights prior to learning
the current task. Hence, it can be used to compute the targets on
the fly (which is more memory efficient (constant memory), but more
computationally demanding).
The computed targets will be detached from the computational graph.
Independent of the current hypernet mode, the targets are computed
in eval mode.
prev_task_embs (list, optional) – If given, prev_theta but not
targets has to be specified. prev_task_embs are the task
embeddings (conditional parameters) of the hypernetwork.
See docstring of prev_theta for more details.
If specified, only a random subset of
previous tasks is regularized. If the given number is bigger than
the number of previous tasks, all previous tasks are regularized.
Note
A batch_size smaller or equal to zero will be ignored
rather than throwing an error.
reg_scaling (list, optional) – If specified, the regularization terms for
the different tasks are scaled according to the entries of this
list.
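To make the role of these arguments concrete, here is a hedged, pure-Python sketch of an output-based regularizer of this kind: for each regularized previous task $j$, it compares the hypernetwork output for task embedding $e_j$ under the candidate weights $\Theta + \Delta\Theta$ with the stored target, optionally weighting each output by a Fisher estimate and scaling each task by reg_scaling. The `hnet(emb, theta)` callable, the flat-list representation, the task subsampling via batch_size and the averaging over tasks are assumptions; dTembs and output-head masking are not modeled.

```python
import random
from typing import Callable, List, Optional, Sequence

Vector = List[float]


def fix_target_reg(hnet: Callable[[Vector, Vector], Vector],
                   theta: Sequence[float],
                   task_embs: Sequence[Vector],
                   targets: Sequence[Vector],
                   d_theta: Optional[Sequence[float]] = None,
                   fisher: Optional[Sequence[Vector]] = None,
                   reg_scaling: Optional[Sequence[float]] = None,
                   batch_size: int = -1, seed: int = 0) -> float:
    """Hypothetical output-regularizer sketch on flat lists (no autograd)."""
    num_prev = len(targets)
    if num_prev == 0:
        return 0.0
    tasks = list(range(num_prev))
    # Regularize a random subset only if 0 < batch_size < num_prev.
    if 0 < batch_size < num_prev:
        tasks = random.Random(seed).sample(tasks, batch_size)
    # Evaluate the hypernet at the weights after the lookahead step.
    if d_theta is not None:
        theta = [t + dt for t, dt in zip(theta, d_theta)]
    reg = 0.0
    for j in tasks:
        output = hnet(task_embs[j], list(theta))
        importance = fisher[j] if fisher is not None else [1.0] * len(output)
        sq_err = sum(f * (o - t) ** 2
                     for f, o, t in zip(importance, output, targets[j]))
        reg += (reg_scaling[j] if reg_scaling is not None else 1.0) * sq_err
    return reg / len(tasks)
```

In a real training loop, the squared error would be computed on tensors so that gradients can flow into $\Theta$ (and, if desired, into $\Delta\Theta$).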
For all previous tasks $j < \text{task\_id}$, compute the output of the
hypernetwork. This output will be detached from the graph before being added
to the return list of this function.
Note, if these targets don’t change during training, it would be more memory
efficient to store the weights of the hypernetwork (which
is a fixed amount of memory compared to the variable number of tasks).
However, it is more computationally expensive to recompute the hypernetwork
outputs for all previous tasks every time the targets are needed.
Note, this function sets the hypernet temporarily in eval mode. No gradients
are computed.
hnet – An instance of the hypernetwork before learning a new task
(i.e., the hypernetwork has the weights necessary
to compute the targets).
Returns:
An empty list, if task_id is 0. Otherwise, a list of
task_id-1 targets. These targets can be passed to the function
calc_fix_target_reg() while training on the new task.
The module utils.init_utils contains helper functions that might be
useful for initialization of weights. The functions are somewhat complementary
to what is already provided in the PyTorch module torch.nn.init.
Initialize the given weight tensor with Xavier fan-in init.
Unfortunately, torch.nn.init.xavier_uniform_() doesn’t give
us the choice to use fan-in init (it always uses the average of fan-in and fan-out).
Therefore, we provide our own implementation.
Parameters:
tensor (torch.Tensor) – Weight tensor that will be modified
(initialized) in-place.
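A minimal sketch of Xavier fan-in initialization in plain Python, assuming the fan-in variant targets a standard deviation of $1/\sqrt{\text{fan\_in}}$. Since $\mathcal{U}(-a, a)$ has standard deviation $a/\sqrt{3}$, the uniform bound becomes $a = \sqrt{3/\text{fan\_in}}$. The fan-in computation follows PyTorch’s shape convention (dimension 1 holds input features or channels, trailing dimensions form the receptive field); all function names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def fan_in(shape: Sequence[int]) -> int:
    """Fan-in for PyTorch-style shapes [out, in, *kernel]."""
    if len(shape) < 2:
        raise ValueError('Fan-in needs a tensor with at least 2 dimensions.')
    receptive_field = 1
    for size in shape[2:]:
        receptive_field *= size
    return shape[1] * receptive_field


def xavier_fan_in_bound(shape: Sequence[int], gain: float = 1.0) -> float:
    # Target std = gain / sqrt(fan_in); U(-a, a) has std a / sqrt(3).
    std = gain / math.sqrt(fan_in(shape))
    return math.sqrt(3.0) * std


def xavier_fan_in_values(shape: Sequence[int], seed: int = 0) -> List[float]:
    """Draw a flat list of initial values (illustration only, no tensors)."""
    bound = xavier_fan_in_bound(shape)
    rng = random.Random(seed)
    return [rng.uniform(-bound, bound) for _ in range(math.prod(shape))]
```

With PyTorch, the same bound could be applied in place via torch.nn.init.uniform_(tensor, -a, a).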
This module implements a biologically-plausible version of a convolutional layer
that does not use weight-sharing. Such a convnet is termed “locally-connected
network” in:
Implementation of a locally-connected 2D convolutional layer.
Since this implementation of a convolutional layer doesn’t use weight-
sharing, it will have more parameters than a conventional convolutional
layer such as torch.nn.Conv2d.
For example, consider a convolutional layer with kernel size [K,K],
C_in input channels and C_out output channels, that has an output
feature map size of [H,W]. Each receptive field [2] will have its
own weights, a parameter tensor of size KxK. Thus, in total the layer
will have C_out*C_in*H*W*K*K weights compared to
C_out*C_in*K*K weights that a conventional
torch.nn.Conv2d would have.
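Consider the $i$-th input feature map $F^{(i)}$
($1 \leq i \leq C_{\text{in}}$), the $j$-th output feature map $G^{(j)}$
($1 \leq j \leq C_{\text{out}}$) and the pixel $G^{(j)}_{xy}$ with
coordinates $(x, y)$ in the $j$-th output feature map
($1 \leq x \leq W$ and $1 \leq y \leq H$).
We denote the filter weights of this pixel connecting to the $i$-th
input feature map by $W^{(i,j)}_{xy} \in \mathbb{R}^{K \times K}$.
The corresponding receptive field inside $F^{(i)}$ that is used to
compute pixel $G^{(j)}_{xy}$ is denoted by
$\hat{F}^{(i)}(x, y) \in \mathbb{R}^{K \times K}$.
The bias weights for feature map $G^{(j)}$ are denoted by
$B^{(j)}$, with a scalar weight $B^{(j)}_{xy}$ for pixel
$(x, y)$.
Using this notation, the computation of this layer can be described by the
following formula

$$
\begin{aligned}
G^{(j)}_{xy} &= B^{(j)}_{xy} + \sum_{i=1}^{C_{\text{in}}} \text{sum}\big(W^{(i,j)}_{xy} \odot \hat{F}^{(i)}(x, y)\big) \\
&= B^{(j)}_{xy} + \sum_{i=1}^{C_{\text{in}}} \big\langle W^{(i,j)}_{xy}, \hat{F}^{(i)}(x, y) \big\rangle_F
\end{aligned}
$$

where $\text{sum}(\cdot)$ is the unary operator that computes the sum
of all elements in a matrix, $\odot$ denotes the Hadamard product
and $\langle \cdot, \cdot \rangle_F$ denotes the Frobenius inner
product, which computes the sum of the entries of the Hadamard product
between real-valued matrices.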
Implementation details
Let $N$ denote the batch size. We can use the function
torch.nn.functional.unfold() to split our input, which is of shape
[N,C_in,H_in,W_in], into receptive fields F_hat of dimension
[N,C_in*K*K,H*W]. The receptive field $\hat{F}^{(i)}(x, y)$
would then correspond to F_hat[:,i*K*K:(i+1)*K*K,y*W+x],
assuming that indices now start at 0 and not at 1 (unfold enumerates the
output locations in row-major order).
In addition, we have a weight tensor W of shape
[C_out,C_in*K*K,H*W].
Now, we can compute the element-wise product of receptive fields and their
filters by introducing a singleton dimension into the shape of F_hat (i.e.,
[N,1,C_in*K*K,H*W]) and by using broadcasting. F_hat*W
will result in a tensor of shape [N,C_out,C_in*K*K,H*W].
By summing over the third dimension (dim=2) and reshaping the output, we
retrieve the result of our local convolutional layer.
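The defining property, one filter per output pixel, is easiest to see in a naive loop implementation. The sketch below is pure Python, assumes stride 1 and no padding, processes a single sample and uses a hypothetical nested-list layout `weights[j][i][y][x]` (a KxK filter per output pixel). It is meant for checking intuitions, not as a substitute for the vectorized unfold-based approach.

```python
from typing import List

Tensor3 = List[List[List[float]]]


def local_conv2d(inp: Tensor3, weights: list, bias: Tensor3,
                 kernel_size: int) -> Tensor3:
    """Naive locally-connected layer (stride 1, no padding, single sample).

    inp:     [C_in][H_in][W_in]
    weights: [C_out][C_in][H][W][K][K]  (one KxK filter per output pixel)
    bias:    [C_out][H][W]
    """
    c_in, h_in, w_in = len(inp), len(inp[0]), len(inp[0][0])
    k = kernel_size
    h_out, w_out = h_in - k + 1, w_in - k + 1
    out = [[[0.0] * w_out for _ in range(h_out)] for _ in range(len(weights))]
    for j in range(len(weights)):
        for y in range(h_out):
            for x in range(w_out):
                acc = bias[j][y][x]
                for i in range(c_in):
                    # Filter used only for output pixel (y, x): no weight sharing.
                    filt = weights[j][i][y][x]
                    # Frobenius inner product with the KxK receptive field.
                    acc += sum(filt[ky][kx] * inp[i][y + ky][x + kx]
                               for ky in range(k) for kx in range(k))
                out[j][y][x] = acc
    return out
```

Changing the filter of a single output pixel only changes that pixel’s output, which is exactly what distinguishes this layer from torch.nn.Conv2d.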
Parameters:
in_channels (int) – Number of channels in the input image.
out_channels (int) – Number of channels produced by the convolution.
in_height (int) – Height of the input feature maps, assuming that input
feature maps have shape [C_in,H,W] (omitting the batch
dimension). This argument is necessary to compute the size of
output feature maps, as we need a filter for each pixel in each
output feature map.
x – The input images of shape [N,C_in,H_in,W_in], where N
denotes the batch size.
weights – Weights that should be used instead of the internally
maintained ones (determined by attribute weights). Note,
if no_weights was True in the constructor, then this
parameter is mandatory.
A list of list of integers. Each list represents the shape of a
parameter tensor. Note, this attribute is independent of the attribute
weights, it always comprises the shapes of all weight tensors as
if the network would be stand-alone (i.e., no weights being passed to
the forward() method).
Configure the logger that should be used by all modules in this
package.
This method sets up a logger, such that all messages are written to console
and to an extra logging file. Both outputs will be the same, except that
a message logged to file contains the module name, where the message comes
from.
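A logger with these properties can be set up with Python’s standard logging module. The following is a hedged sketch with a hypothetical function name and format strings; it is not the package’s own configuration function. The file handler’s format additionally records the module that emitted the message via `%(module)s`.

```python
import logging
import sys


def config_logger(name: str, log_file: str,
                  level: int = logging.INFO) -> logging.Logger:
    """Hypothetical sketch: log to console and to a file, same messages."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False  # avoid duplicate output via the root logger
    # Remove handlers from earlier calls so messages are not emitted twice.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s'))
    # The file format additionally records the module that emitted the message.
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(logging.Formatter(
        '%(asctime)s - %(levelname)s - %(module)s - %(message)s'))

    logger.addHandler(console)
    logger.addHandler(file_handler)
    return logger
```

Other modules can then retrieve the same logger via logging.getLogger(name) without configuring it again.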
The implementation is based on an earlier implementation of a function I
used in another project:
Initialize the weights and biases of a linear or (transpose) conv layer.
Note, the implementation is based on the method “reset_parameters()”,
that defines the original PyTorch initialization for a linear or
convolutional layer, resp. The implementations can be found here:
If a figure is written to tensorboard via “add_figure”, its canvas might be
changed such that our backend can no longer show the figure.
This method will generate a new canvas and replace the old one of the
given figure.
Parameters:
fig – The figure to be shown.
close – Whether the figure should be closed after it has been shown.
Helper function to convert a string which is a list of comma separated
floats into an actual list of floats.
Parameters:
str_arg – String containing list of comma-separated floats. For
convenience reasons, we allow the user to also pass a single float
that is put into a list of length 1 by this function.
Helper function to convert a string which is a list of comma separated
integers into an actual list of integers.
Parameters:
str_arg – String containing list of comma-separated ints. For convenience
reasons, we allow the user to also pass a single integer that is put
into a list of length 1 by this function.
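Both helpers can be sketched in a few lines. The function names are hypothetical, and returning an empty list for an empty string is an assumption the original does not specify.

```python
from typing import List


def str_to_floats(str_arg: str) -> List[float]:
    """'0.1, 2,3.5' -> [0.1, 2.0, 3.5]; a single value yields a 1-element list."""
    str_arg = str_arg.strip()
    if not str_arg:
        return []  # assumption: an empty string maps to an empty list
    return [float(item) for item in str_arg.split(',')]


def str_to_ints(str_arg: str) -> List[int]:
    """'1,2, 3' -> [1, 2, 3]; a single value yields a 1-element list."""
    str_arg = str_arg.strip()
    if not str_arg:
        return []
    # int() tolerates surrounding whitespace, e.g., int(' 3') == 3.
    return [int(item) for item in str_arg.split(',')]
```

Such helpers are typically applied to command-line arguments, e.g., a string argument like `--lrs 0.1,0.01` parsed into `[0.1, 0.01]`.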
PyTorch optimizers don’t provide the ability to get a lookahead of the change to
the parameters applied by the torch.optim.Optimizer.step() method.
Therefore, this module copies step() functions from some optimizers, but
without applying the weight change and without making changes to the internal
state of an optimizer, such that the user can get the change of parameters that
would be executed by the optimizer.
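To illustrate the idea, here is a hedged sketch of an Adam lookahead in plain Python. It recomputes the moment estimates from the current gradient and state, but only returns the change d_p and never writes back to the state. It assumes the update rule of torch.optim.Adam without weight decay or AMSGrad, operates on lists of scalars instead of tensors, and uses a hypothetical dict layout for the state.

```python
import math
from typing import Any, Dict, List, Sequence, Tuple


def adam_lookahead(grads: Sequence[float], state: Dict[str, Any],
                   lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999),
                   eps: float = 1e-8) -> List[float]:
    """Return the change Adam would apply to each (scalar) parameter.

    The state dict (keys 'step', 'exp_avg', 'exp_avg_sq') is only read, so
    calling this function has no side effects.
    """
    beta1, beta2 = betas
    step = int(state.get('step', 0)) + 1
    exp_avg = state.get('exp_avg', [0.0] * len(grads))
    exp_avg_sq = state.get('exp_avg_sq', [0.0] * len(grads))
    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2 = 1.0 - beta2 ** step
    d_p = []
    for g, m, v in zip(grads, exp_avg, exp_avg_sq):
        # Local copies of the updated moments; nothing is written back.
        m_new = beta1 * m + (1.0 - beta1) * g
        v_new = beta2 * v + (1.0 - beta2) * g * g
        denom = math.sqrt(v_new) / math.sqrt(bias_correction2) + eps
        d_p.append(-lr / bias_correction1 * m_new / denom)
    return d_p
```

On the very first step, the bias corrections make each entry of d_p approximately $-\text{lr} \cdot \text{sign}(g)$, independent of the gradient magnitude.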
detach_dp – Whether gradients are detached from the computational
graph. Note, False only makes sense if
torch.autograd.backward() was called with the argument
create_graph set to True.
Returns:
A list of gradient changes d_p that would be applied by this
optimizer to all parameters when calling torch.optim.Adam.step().
Note, by default, gradients are detached from the computational graph.
Parameters:
optimizer – The optimizer that will be used to change $\theta$.
use_sgd_change – If True, then we won’t calculate the actual step
done by the current optimizer, but the one that would be done by a
simple SGD optimizer.
lr – Has to be specified if use_sgd_change is True. The
learning rate of the optimizer.
detach_dt – Whether $\Delta\theta$ should be detached from the
computational graph. Note, in order to backprop through
$\Delta\theta$, you have to call
torch.autograd.backward() with create_graph set to
True before calling this method.
detach_dp – Whether gradients are detached from the computational
graph. Note, False only makes sense if
torch.autograd.backward() was called with the argument
create_graph set to True.
Returns:
A list of gradient changes d_p that would be applied by this
optimizer to all parameters when calling
torch.optim.RMSprop.step().
detach_dp – Whether gradients are detached from the computational
graph. Note, False only makes sense if
torch.autograd.backward() was called with the argument
create_graph set to True.
Returns:
A list of gradient changes d_p that would be applied by this
optimizer to all parameters when calling torch.optim.SGD.step().
Self-Attention Layer with weights maintained separately. Hence, this
class should have exactly the same behavior as “SelfAttnLayer”, but the weights
are maintained independently of the pre-implemented PyTorch modules, which
allows more flexibility (e.g., generating weights by a hypernet or modifying
weights easily).
The goal is to capture global correlations in convolutional networks (such
as generators and discriminators in GANs).
Initialize self-attention layer.
Parameters:
in_dim – Number of input channels (C).
use_spectral_norm – Enable spectral normalization for all 1x1 convolutional
layers.
no_weights – If set to True, no trainable parameters will be
constructed, i.e., weights are assumed to be produced ad-hoc
by a hypernetwork and passed to the forward function.
init_weights (optional) – This option is for convenience reasons.
The option expects a list of parameter values that are used to
initialize the network weights. As such, it provides a
convenient way of initializing a network with a weight draw
produced by the hypernetwork.
See attribute “weight_shapes” for the format in which parameters
should be passed.
Compute and apply attention map to mix global information into local
features.
Parameters:
x – Input feature maps (shape: B x C x W x H).
ret_attention (optional) – If the attention map should be returned
as an additional return value.
weights – List of weight tensors, that are used as layer parameters.
If “no_weights” was set in the constructor, then this parameter
is mandatory.
Note, when provided, internal parameters are not used.
dWeights – List of weight tensors, that are added to “weights” (the
internal list of parameters or the one given via the option
“weights”), when computing the output of this network.
Returns:
Tuple (if ret_attention is True) containing:
out: gamma * (self-)attention features + input features.
attention: Attention map, shape: B x N x N (N = W * H).
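The computation described above can be sketched for a single sample in plain Python, assuming a SAGAN-style layout: 1x1 convolutions (matrix-vector products per position) produce query, key and value features at each of the N = W * H positions, a softmax over query-key products yields the N x N attention map, and the output is gamma times the attended values plus the input. The exact projection sizes of SelfAttnLayer may differ; all names here are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def _project(weight: Sequence[Sequence[float]], vec: Sequence[float]) -> List[float]:
    # A 1x1 convolution applied at one spatial position is a matrix-vector product.
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def _softmax(row: Sequence[float]) -> List[float]:
    peak = max(row)  # subtract the maximum for numerical stability
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x: Matrix, w_query: Matrix, w_key: Matrix,
                   w_value: Matrix, gamma: float) -> Tuple[Matrix, Matrix]:
    """x: N feature vectors of length C (one per position, N = W * H)."""
    queries = [_project(w_query, xi) for xi in x]
    keys = [_project(w_key, xi) for xi in x]
    values = [_project(w_value, xi) for xi in x]
    # attention[i][j]: how much position i attends to position j (rows sum to 1).
    attention = [_softmax([sum(a * b for a, b in zip(qi, kj)) for kj in keys])
                 for qi in queries]
    out = []
    for i, xi in enumerate(x):
        mixed = [sum(attention[i][j] * values[j][c] for j in range(len(x)))
                 for c in range(len(xi))]
        # Residual connection scaled by the learnable gamma.
        out.append([gamma * m + v for m, v in zip(mixed, xi)])
    return out, attention
```

With gamma = 0 the layer reduces to the identity, so the attention path can be faded in gradually during training.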
This function is called after an optimizer update step has been performed.
It will perform an update of the internal running variable $\omega$
using the current parameter values, the checkpointed parameter values
before the optimizer step ($\tilde{\theta}$, see function
si_pre_optim_step()) and the negative gradients accumulated in the
grad variables of the parameters.
One may pass the parameter update step directly.
In this case, the difference between the current parameter values
and the previous ones $\tilde{\theta}$ will not be
computed.
Note
One may use the functions provided in module
utils.optim_step to calculate delta_params
Note
When this option is used, it is not required to explicitly call
the optimizer’s step function. However, it is still
required that gradients are computed and accumulated in the
grad variables of the parameters in params.
Note
This option is particularly interesting if importances should
only be estimated with respect to a part of the total loss function,
e.g., the task-specific part, ignoring other parts of the loss
(e.g., regularizers).
Prepare SI importance estimate before running the optimizer step.
This function has to be called before running the optimizer step in order
to checkpoint $\tilde{\theta}$.
Note
When this function is called the first time (for the first task), the
given parameters will also be checkpointed as the initial weights,
which are required to normalize importances $\Omega$ after
training.
Parameters:
net (torch.nn.Module) – A network required to store buffers (i.e., the
running variables that SI needs to keep track of).
params (list) – A list of parameter tensors. For each parameter tensor
in this list that requires_grad the importances will be
measured.
params_name (str, optional) – In case SI should be performed for
multiple parameter groups params, one has to assign names to
each group via this option.
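The bookkeeping of these two functions can be sketched with a small hypothetical class on lists of scalars, following the path-integral importance of synaptic intelligence (Zenke et al., 2017): before each step the parameters are checkpointed, after the step $\omega_k \mathrel{+}= -g_k (\theta_k - \tilde{\theta}_k)$, and after training $\Omega_k = \omega_k / ((\theta_k - \theta_k^{\text{start}})^2 + \xi)$. The damping value $\xi$ and all names are assumptions, not the module’s API.

```python
from typing import List, Sequence


class SIImportanceSketch:
    """Hypothetical tracker for SI's running importance omega (not the module's API)."""

    def __init__(self, initial_params: Sequence[float], damping: float = 1e-3):
        self.theta_start = list(initial_params)  # checkpointed at the first task
        self.omega = [0.0] * len(initial_params)
        self.damping = damping  # xi in the normalization (assumed value)
        self._theta_before_step: List[float] = []

    def pre_optim_step(self, params: Sequence[float]) -> None:
        # Checkpoint theta_tilde before the optimizer changes the parameters.
        self._theta_before_step = list(params)

    def post_optim_step(self, params: Sequence[float],
                        grads: Sequence[float]) -> None:
        # omega_k += (-g_k) * (theta_k - theta_tilde_k)
        for k, (new, old, g) in enumerate(
                zip(params, self._theta_before_step, grads)):
            self.omega[k] += -g * (new - old)

    def importances(self, params: Sequence[float]) -> List[float]:
        # Omega_k = omega_k / ((theta_k - theta_start_k)^2 + xi)
        return [w / ((p - p0) ** 2 + self.damping)
                for w, p, p0 in zip(self.omega, params, self.theta_start)]
```

For plain gradient descent, $\omega$ approximately tracks how much each parameter contributed to the decrease of the loss.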
A prefix of the config names. The config names used in this method
might be prefixed, since several main networks may be generated
(e.g., cprefix='gen_' or 'dis_' when training a GAN).
Also see docstring of parameter prefix in function
utils.cli_args.main_net_args().
no_weights (bool) – Whether the main network should be generated without
weights.
**mnet_kwargs – Additional keyword arguments that will be passed to the
main network constructor.
This function should be called at the beginning of a simulation script
(right after the command-line arguments have been parsed). The setup will
incorporate:
creating the output folder
initializing logger
making computation deterministic (depending on config)
net – The network, that should load the state dict saved in this
checkpoint.
device (optional) – The device currently used by the model. Can help to
speed up loading the checkpoint.
ret_performance_score – If True, the score associated with this
checkpoint will be returned as well. See argument
“performance_score” of method “save_checkpoint”.
Returns:
The loaded checkpoint. Note, the state_dict is already applied to the
network. However, there might be other important dict elements.
ckpt_dict – A dict with mostly arbitrary content. Though, most important,
it needs to include the state dict and should also include
the current training iteration.
file_path –
Where to store the checkpoint. Note, the filepath should
not change. Instead, train_iter should be provided,
such that this method can handle the filenames by itself.
Note
The function currently assumes that within the same directory,
no checkpoint filename is the prefix of another
checkpoint filename (e.g., if several networks are checkpointed
into the same directory).
performance_score – A score that expresses the performance of the
current network state, e.g., accuracy for a
classification task. This score is used to
maintain the list of kept checkpoints during
training.
train_iter (optional) – If given, it will be added to the filename.
Otherwise, existing checkpoints are simply overwritten.
max_ckpts_to_keep – The maximum number of checkpoints to
keep. This will use the performance score to determine the n-1
checkpoints not to be deleted (where n is the number of
checkpoints to keep). The current checkpoint will always be saved.
keep_cktp_every – If this option is not None,
then every keep_cktp_every hours one checkpoint will be permanently saved, i.e.,
this checkpoint will no longer be maintained by ‘max_ckpts_to_keep’.
The checkpoint to be kept will be the best one from the
time window that spans the last keep_cktp_every hours.
timestamp (optional) – The timestamp of this checkpoint. If not given,
a current timestamp will be used. This option is useful when one
aims to synchronize checkpoint savings from multiple networks.
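The retention rule of max_ckpts_to_keep (always keep the new checkpoint plus the n-1 best previous ones) can be sketched as a pure function. The name, the (filename, score) representation, unique filenames and "higher score is better" are assumptions; the time-based keep_cktp_every logic is not modeled.

```python
from typing import List, Sequence, Tuple


def ckpts_to_delete(existing: Sequence[Tuple[str, float]],
                    max_ckpts_to_keep: int) -> List[str]:
    """Return names of old checkpoints to delete once a new one is saved.

    existing: (filename, performance_score) pairs of checkpoints already on
    disk, excluding the new checkpoint (which is always kept). Higher scores
    are assumed to be better.
    """
    if max_ckpts_to_keep < 1:
        raise ValueError('At least the current checkpoint has to be kept.')
    ranked = sorted(existing, key=lambda item: item[1], reverse=True)
    keep = {name for name, _ in ranked[:max_ckpts_to_keep - 1]}
    return [name for name, _ in existing if name not in keep]
```

For example, with three existing checkpoints scored 0.5, 0.9 and 0.7 and max_ckpts_to_keep=3, only the one scored 0.5 would be deleted.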
A collection of helper functions that should capture common functionalities
needed when working with PyTorch.
pgroup_ids (list, optional) – If passed, a list of integers of the same
length as params is expected. In this case, each integer states to
which parameter group the corresponding parameter in params
shall belong. Parameter groups may have different optimizer
settings. Therefore, options like lr, momentum,
weight_decay, adam_beta1 may be lists in this case that have
a length corresponding to the number of parameter groups.
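The grouping logic can be sketched as follows, producing the list-of-dicts format that PyTorch optimizers accept for parameter groups. The function name is hypothetical, group ids are assumed to be 0, 1, ..., and only list-valued options are treated as per-group settings (so tuple-valued options such as Adam’s betas stay shared).

```python
from typing import Any, Dict, List, Sequence


def build_param_groups(params: Sequence[Any], pgroup_ids: Sequence[int],
                       **options: Any) -> List[Dict[str, Any]]:
    """Group params by id; list-valued options are assigned per group."""
    if len(params) != len(pgroup_ids):
        raise ValueError('pgroup_ids must have the same length as params.')
    num_groups = max(pgroup_ids) + 1
    groups: List[Dict[str, Any]] = [{'params': []} for _ in range(num_groups)]
    for param, gid in zip(params, pgroup_ids):
        groups[gid]['params'].append(param)
    for key, value in options.items():
        if isinstance(value, list):
            # One setting per parameter group.
            if len(value) != num_groups:
                raise ValueError(f'Option {key} needs {num_groups} entries.')
            for group, group_value in zip(groups, value):
                group[key] = group_value
        else:
            # Shared setting for all groups.
            for group in groups:
                group[key] = value
    return groups
```

With actual tensors as params, the result could be passed to an optimizer, e.g., torch.optim.SGD(groups, lr=0.1), where the lr argument only serves as the default for groups that do not set their own.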
Initialize the weights and biases of a linear or (transpose) conv layer.
Note, the implementation is based on the method “reset_parameters()”,
that defines the original PyTorch initialization for a linear or
convolutional layer, resp. The implementations can be found here:
Computes a multiplicative factor for the initial learning rate based
on the current epoch. This method can be used as argument
lr_lambda of class torch.optim.lr_scheduler.LambdaLR.
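A function of this kind maps the epoch index to a factor that multiplies the initial learning rate. The step schedule below, with its milestones and decay, is purely hypothetical and not necessarily the schedule implemented here; it only illustrates the expected interface.

```python
from typing import Sequence


def step_lr_factor(epoch: int, milestones: Sequence[int] = (80, 120, 160),
                   decay: float = 0.1) -> float:
    """Multiplicative LR factor: decay once for every milestone reached."""
    factor = 1.0
    for milestone in milestones:
        if epoch >= milestone:
            factor *= decay
    return factor
```

It could be used as torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=step_lr_factor), calling scheduler.step() once per epoch.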