# Source code for hypnettorch.utils.ewc_regularizer

#!/usr/bin/env python3
# Copyright 2019 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title           :utils/ewc_regularizer.py
# @author          :ch
# @contact         :henningc@ethz.ch
# @created         :05/07/2019
# @version         :1.0
# @python_version  :3.6.8
"""
Elastic Weight Consolidation
----------------------------

Implementation of EWC:
    https://arxiv.org/abs/1612.00796

Note, these implementations are based on the descriptions provided in:
    https://arxiv.org/abs/1809.10635

The code is inspired by the corresponding implementation:
    https://git.io/fjcnL
"""
import torch
from torch.nn import functional as F
from warnings import warn

from hypnettorch.mnets.mnet_interface import MainNetInterface

def compute_fisher(task_id, data, params, device, mnet, hnet=None,
                   empirical_fisher=True, online=False, gamma=1., n_max=-1,
                   regression=False, time_series=False, allowed_outputs=None,
                   custom_forward=None, custom_nll=None, pass_ids=False,
                   proper_scaling=False, prior_strength=None,
                   regression_lvar=1., target_manipulator=None):
    r"""Compute estimates of the diagonal elements of the Fisher information
    matrix, as needed as importance-weights by elastic weight consolidation
    (EWC).

    The Fisher matrix for a conditional distribution
    :math:`p(y \mid \theta, x)` (i.e., the model likelihood for a model with
    parameters :math:`\theta`) is defined as follows at location :math:`x`

    .. math::

        \mathcal{F}(x) &= \textrm{Var} \big[ \nabla_{\theta}
            \log p(y \mid \theta, x) \big] \\
        &= \mathbb{E}_{p(y \mid \theta, x)} \big[ \nabla_{\theta}
            \log p(y \mid \theta, x) \nabla_{\theta}
            \log p(y \mid \theta, x)^T\big]

    In practice, we are often interested in the Fisher averaged over
    locations

    .. math::

        \mathcal{F} = \mathbb{E}_{p(x)} [ \mathcal{F}(x) ]

    Since the model is trained such that, in-distribution, the model
    likelihood :math:`p(y \mid \theta, x)` and the ground-truth likelihood
    :math:`p(y \mid x)` agree, people often refer to the empirical Fisher,
    which utilizes the dataset for computation and therewith doesn't require
    sampling from the model likelihood. Note, EWC anyway assumes that
    in-distribution :math:`p(y \mid \theta, x) = p(y \mid x)` in order to be
    able to replace the Hessian by the Fisher matrix.

    .. math::

        \mathcal{F}_{emp} &= \mathbb{E}_{p(x,y)} \big[ \nabla_{\theta}
            \log p(y \mid \theta, x) \nabla_{\theta}
            \log p(y \mid \theta, x)^T\big] \\
        &= \mathbb{E}_{p(x)} \Big[ \mathbb{E}_{p(y \mid x)} \big[
            \nabla_{\theta}\log p(y \mid \theta, x) \nabla_{\theta}
            \log p(y \mid \theta, x)^T\big] \Big] \\
        &\approx \frac{1}{|\mathcal{D}|} \sum_{(x_n, y_n) \sim \mathcal{D}}
            \big[ \nabla_{\theta}\log p(y_n \mid \theta, x_n)
            \nabla_{\theta}\log p(y_n \mid \theta, x_n)^T\big]

    Note:
        This method registers buffers in the given module (storing the
        current parameters and the estimate of the Fisher diagonal
        elements), i.e., the ``mnet`` if ``hnet`` is ``None``, otherwise the
        ``hnet``.

    Note:
        Per-sample gradients are obtained via :func:`torch.autograd.grad`.
        Hence, the ``.grad`` attributes of ``params`` (and of any other
        tensor in the computational graph) are left untouched.

    Args:
        task_id: The ID of the current task, needed to store the computed
            tensors with a unique name. When ``hnet`` is given, it is used as
            input to the ``hnet`` forward method to select the current task
            embedding.
        data: A data handler. We will compute the Fisher estimate across the
            whole training set (unless ``n_max`` is specified).
        params: A list of parameter tensors from the module of which we aim to
            compute the Fisher for. If ``hnet`` is given, then these are
            assumed to be the "theta" parameters, that we pass to the forward
            function of the hypernetwork. Otherwise, these are the "weights"
            passed to the forward method of the main network.
            Note, they must not be detached from their original parameters,
            because gradients of the NLL with respect to these tensors are
            computed via autograd.
            Note, the order in which these parameters are passed to this
            method and the corresponding EWC loss function must not change,
            because the index within the "params" list will be used as unique
            identifier.
        device: Current PyTorch device.
        mnet: The main network. If ``hnet`` is ``None``, then ``params`` are
            assumed to belong to this network. The Fisher estimate will be
            computed accordingly.
            Note, ``params`` might be the output of a task-conditioned
            hypernetwork, i.e., weights for a specific task. In this case,
            "online"-EWC doesn't make much sense, as we don't follow the
            Bayesian view of using the old task weights as prior for the
            current ones. Instead, we have a new set of weights for each task.
        hnet (optional): If given, ``params`` is assumed to correspond to the
            unconditional weights :math:`\theta` (which does not include, for
            instance, task embeddings) of the hypernetwork. In this case, the
            diagonal Fisher entries belong to weights of the hypernetwork. The
            Fisher will then be computed based on the probability
            :math:`p(y \mid x, \text{task\_id})`, where ``task_id`` is just a
            constant input (representing the corresponding conditional
            weights, e.g., task embedding) in addition to the training samples
            :math:`x`.
            Note, this option is not implemented yet; passing an ``hnet``
            currently raises a :exc:`NotImplementedError`.
        empirical_fisher: If ``True``, we compute the Fisher based on training
            targets. (Only the empirical Fisher is implemented so far.)
        online: If ``True``, then we use online EWC, hence, there is only one
            diagonal Fisher approximation and one target parameter value
            stored at the time, rather than for all previous tasks.
        gamma: The gamma parameter for online EWC, controlling the gradual
            decay of previous tasks.
        n_max (optional): If not ``-1``, this will be the maximum amount of
            samples considered for estimating the Fisher.
        regression: Whether the task at hand is a classification or regression
            task. If ``True``, a regression task is assumed.

            For simplicity, we assume the following probabilistic model
            :math:`p(y \mid x) = \mathcal{N}\big(f(x), \sigma^2 I\big)` with
            :math:`I` being the identity matrix and
            :math:`\sigma^2` = ``regression_lvar``. In this case, the only term
            of the log probability that influences the gradient is the
            (scaled) squared error:
            :math:`\log p(y \mid x) = - \frac{1}{2 \sigma^2}
            \lVert f(x) - y \rVert^2 + \text{const}`
        time_series (bool): If ``True``, the output of the main network
            ``mnet`` is expected to be a time series. In particular, we
            assume that the output is a tensor of shape ``[S, N, F]``,
            where ``S`` is the length of the time series, ``N`` is the batch
            size and ``F`` is the size of each feature vector (e.g., in
            classification, ``F`` would be the number of classes).

            Let :math:`\mathbf{y} = (\mathbf{y}_1, \dots \mathbf{y}_S)` be the
            output of the main network. We denote the parameters ``params`` by
            :math:`\theta` and the input by :math:`\mathbf{x}` (which we do not
            consider as random). We use the following decomposition of the
            likelihood

            .. math::

                p(\mathbf{y} \mid \theta; \mathbf{x}) =
                \prod_{i=1}^S p(\mathbf{y}_i \mid \mathbf{y}_1, \dots,
                \mathbf{y}_{i-1}, \theta; \mathbf{x}_i)

            **Classification:** If
            :math:`f(\mathbf{x}_i, \mathbf{h}_{i-1}, \theta)` denotes the
            output of the main network ``mnet`` for timestep :math:`i`
            (assuming :math:`\mathbf{h}_{i-1}` is the most recent hidden
            state), we assume

            .. math::

                p(\mathbf{y}_i \mid \mathbf{y}_1, \dots, \mathbf{y}_{i-1},
                \theta; \mathbf{x}_i) \equiv \text{softmax} \big(
                f(\mathbf{x}_i, \mathbf{h}_{i-1}, \theta) \big)

            Hence, we assume that we can write the negative log-likelihood
            (NLL) as follows given a label :math:`t \in [1, \dots, F]^S`:

            .. math::

                \text{NLL} &= - \log p(Y = t \mid \theta; \mathbf{x}) \\
                &= \sum_{i=1}^S - \log \text{softmax} \big(
                    f(\mathbf{x}_i, \mathbf{h}_{i-1}, \theta) \big)_{t_i} \\
                &= \sum_{i=1}^S \text{cross\_entropy} \big(
                    f(\mathbf{x}_i, \mathbf{h}_{i-1}, \theta), t_i \big)

            Thus, we simply sum the cross-entropy losses per time-step to
            estimate the NLL, which we then backpropagate through in order to
            compute the diagonal Fisher elements.
        allowed_outputs (optional): A list of indices, indicating which output
            neurons of the main network should be taken into account when
            computing the log probability. If not specified, all output
            neurons are considered.
        custom_forward (optional): A function handle that can replace the
            default procedure of forwarding samples through the given
            network(s).

            The default forward procedure if ``hnet`` is ``None`` is

            .. code:: python

                Y = mnet.forward(X, weights=params)

            Otherwise, the default forward procedure is

            .. code:: python

                weights = hnet.forward(task_id, theta=params)
                Y = mnet.forward(X, weights=weights)

            The signature of this function should be as follows.
                - ``hnet`` is ``None``: :code:`@fun(mnet, params, X)`
                - ``hnet`` is not ``None``:
                  :code:`@fun(mnet, hnet, task_id, params, X)`
            where :code:`X` denotes the input batch to the main network
            (usually consisting of a single sample).

            Example:
                Imagine a situation where the main network uses context-
                dependent modulation (cmp.
                :class:`utils.context_mod_layer.ContextModLayer`) and the
                parameters of these context-mod layers are produced by the
                hypernetwork ``hnet``, whereas the remaining weights of the
                main network ``mnet`` are maintained internally and passed as
                argument ``params`` to this method.

                In particular, we look at a main network that is an instance
                of class :class:`mnets.mlp.MLP`. The forward pass through this
                combination of networks should be handled as follows in order
                to compute the correct Fisher matrix:

                .. code:: python

                    def custom_forward(mnet, hnet, task_id, params, X):
                        mod_weights = hnet.forward(task_id)
                        weights = {
                            'mod_weights': mod_weights,
                            'internal_weights': params
                        }
                        Y = mnet.forward(X, weights=weights)
                        return Y
        custom_nll (optional): A function handle that can replace the default
            procedure of computing the negative-log-likelihood (NLL), which is
            required to compute the Fisher.

            The signature of this function should be as follows:
                :code:`@fun(Y, T, data, allowed_outputs, empirical_fisher)`
            where ``Y`` are the outputs of the main network. Note,
            ``allowed_outputs`` have already been applied to ``Y``, if given.
            ``T`` is the target provided by the dataset ``data``, transformed
            as follows:

            .. code:: python

                T = data.output_to_torch_tensor(batch[1], device,
                                                mode='inference')

            The arguments ``data``, ``allowed_outputs`` and
            ``empirical_fisher`` are only passed for convenience (e.g., to
            apply simple sanity checks using assertions).

            The output of the function handle should be the NLL for the given
            sample.
        pass_ids (bool): If a ``custom_nll`` is used and this flag is
            ``True``, then the signature of the ``custom_nll`` is expected to
            be:

            .. code:: python

                @fun(Y, T, data, allowed_outputs, empirical_fisher, batch_ids)

            where ``batch_ids`` are the unique identifiers as returned by
            option ``return_ids`` of method
            :meth:`data.dataset.Dataset.next_train_batch` corresponding to the
            provided samples.

            Example:
                In sequential datasets, target sequences ``T`` might be padded
                to the same length. Though, if the unpadded length should be
                used for NLL computation, then the ``custom_nll`` function
                needs the ability to request this information (sequence
                length) from ``data``.

            Also, the signatures of ``custom_forward`` are expected to be
            different. The signature of this function should be as follows.
                - ``hnet`` is ``None``:
                  ``@fun(mnet, params, X, data, batch_ids)``
                - ``hnet`` is not ``None``:
                  ``@fun(mnet, hnet, task_id, params, X, data, batch_ids)``
        proper_scaling (bool): The algorithm `Online EWC` is based on a Taylor
            approximation of the posterior that leads to the following
            estimate

            .. math::

                \log p(\theta \mid \mathcal{D}_1, \cdots, \mathcal{D}_T)
                \approx \log p(\mathcal{D}_T \mid \theta) - \frac{1}{2}\sum_i
                \bigg( \sum_{t < T} N_t \mathcal{F}_{emp \hspace{1mm}t, i} +
                \frac{1}{\sigma_{prior}^2} \bigg) (\theta_i -
                \theta_{S, i}^*)^2 + \text{const}

            Due to the presentation of the algorithm in the paper and inspired
            by multiple public implementations, we approximate the
            regularization strength in practice via

            .. math::

                \sum_{t < T} N_t \mathcal{F}_{emp \hspace{1mm}t, i} +
                \frac{1}{\sigma_{prior}^2} \approx \lambda \sum_{t < T}
                \mathcal{F}_{emp \hspace{1mm}t, i}

            where :math:`\lambda` is a hyperparameter.

            If this argument is ``True``, then the sum of Fisher matrices is
            properly weighted by the dataset size (independent of argument
            ``n_max``).
        prior_strength (float or list, optional): Either a scalar or a list of
            Tensors with the same shapes as ``params``. Only applies to
            `Online EWC`. One can specify an offset for all Fisher values,
            e.g., :math:`\frac{1}{\sigma_{prior}^2}`. See argument
            ``proper_scaling`` for details. The offset is added after the
            Fisher estimate has been scaled (see ``proper_scaling``), so it is
            not affected by the number of samples used.
        regression_lvar (float): In regression, this refers to the variance of
            the likelihood.
        target_manipulator (func, optional): A function with signature

            .. code:: python

                T = target_manipulator(T)

            That may manipulate the targets coming from the dataset.

    Raises:
        NotImplementedError: If ``hnet`` is given, if ``time_series`` and
            ``regression`` are both set, or if ``empirical_fisher`` is
            ``False`` (and no ``custom_nll`` is given).
        ValueError: If ``prior_strength`` is given without ``online``, or if
            no training samples are available.
    """
    assert isinstance(mnet, MainNetInterface)
    assert mnet.has_linear_out

    # TODO Computing the Fisher w.r.t. hypernetwork weights is not supported
    # yet. The `hnet` code paths below are kept for when it will be.
    if hnet is not None:
        raise NotImplementedError()

    assert hnet is None or task_id is not None
    assert not online or (gamma >= 0. and gamma <= 1.)
    assert n_max == -1 or n_max > 0

    if time_series and regression:
        raise NotImplementedError('Computing the Fisher for a recurrent ' +
                                  'regression task is not yet implemented.')

    if not online:
        if proper_scaling:
            # Doesn't hurt, we can get rid of warning.
            warn('Argument "proper_scaling" is only well justified in ' +
                 'Online EWC.')
        if prior_strength is not None:
            # We have a separate Fisher per task for EWC.
            raise ValueError('Option "prior_strength" only applicable to ' +
                             'Online EWC.')

    n_samples = data.num_train_samples
    if n_max != -1:
        n_samples = min(n_samples, n_max)
    if n_samples < 1:
        # Otherwise, the normalization below would produce NaN Fisher values.
        raise ValueError('Cannot estimate the Fisher without training ' +
                         'samples.')

    for ii, p in enumerate(params):
        assert p.requires_grad # Otherwise, we can't compute the Fisher.
        if isinstance(prior_strength, (list, tuple)):
            assert len(prior_strength) == len(params) and \
                prior_strength[ii].shape == p.shape

    mnet_mode = mnet.training
    mnet.eval()
    if hnet is not None:
        hnet_mode = hnet.training
        hnet.eval()

    try:
        fisher = _accumulate_squared_grads(task_id, data, params, device,
            mnet, hnet, n_samples, empirical_fisher, regression, time_series,
            allowed_outputs, custom_forward, custom_nll, pass_ids,
            regression_lvar, target_manipulator)
    finally:
        # Restore the original train/eval modes, even if an error occurred.
        mnet.train(mode=mnet_mode)
        if hnet is not None:
            hnet.train(mode=hnet_mode)

    for i in range(len(params)):
        if not proper_scaling:
            fisher[i] /= n_samples
        elif n_samples != data.num_train_samples:
            fisher[i] *= data.num_train_samples / n_samples

        # The prior offset is added after rescaling, such that it is a true
        # offset of the Fisher values (independent of `n_samples`).
        if prior_strength is None:
            continue
        if isinstance(prior_strength, (list, tuple)):
            fisher[i] += prior_strength[i].to(fisher[i])
        else:
            fisher[i] += prior_strength

    ### Register buffers to store current task weights as well as the Fisher.
    net = mnet if hnet is None else hnet
    _register_ewc_buffers(net, task_id, params, fisher, online, gamma)


def _accumulate_squared_grads(task_id, data, params, device, mnet, hnet,
                              n_samples, empirical_fisher, regression,
                              time_series, allowed_outputs, custom_forward,
                              custom_nll, pass_ids, regression_lvar,
                              target_manipulator):
    """Sum the squared per-sample NLL gradients w.r.t. ``params`` over
    ``n_samples`` training samples (helper of :func:`compute_fisher`).

    Returns:
        (list): One tensor per entry of ``params`` (same shape), holding the
        unnormalized sum of squared gradients.
    """
    fisher = [torch.zeros_like(p) for p in params]

    # Ensure, that we go through all training samples (note, that training
    # samples are always randomly shuffled when using "next_train_batch", but
    # we always go through the complete batch before reshuffling the samples.)
    # If `n_max` was specified, we always go through a different random
    # subsample of the training set.
    data.reset_batch_generator(train=True, test=False, val=False)

    # The Fisher diagonal needs the square of *per-sample* gradients, hence
    # samples are processed one at a time.
    for _ in range(n_samples):
        batch = data.next_train_batch(1, return_ids=pass_ids)
        X = data.input_to_torch_tensor(batch[0], device, mode='inference')
        T = data.output_to_torch_tensor(batch[1], device, mode='inference')
        if target_manipulator is not None:
            T = target_manipulator(T)
        batch_ids = batch[2] if pass_ids else None

        Y = _fisher_forward(mnet, hnet, task_id, params, X, data, batch_ids,
                            custom_forward, pass_ids)

        if not time_series:
            assert len(Y.shape) == 2
        else:
            assert len(Y.shape) == 3

        if allowed_outputs is not None:
            if not time_series:
                Y = Y[:, allowed_outputs]
            else:
                Y = Y[:, :, allowed_outputs]

        nll = _fisher_nll(Y, T, data, device, allowed_outputs,
                          empirical_fisher, custom_nll, pass_ids, batch_ids,
                          regression, regression_lvar, time_series,
                          target_manipulator)

        # `autograd.grad` returns the gradients directly instead of
        # accumulating them into `.grad`. This works for tensors that are not
        # registered in `mnet` (or are non-leaf tensors) and doesn't pollute
        # the `.grad` of any other tensor in the graph.
        grads = torch.autograd.grad(nll, params, retain_graph=False,
                                    create_graph=False, allow_unused=True)
        for i, g in enumerate(grads):
            if g is not None: # `None` if the tensor doesn't affect the NLL.
                fisher[i] += torch.pow(g.detach(), 2)

    return fisher


def _fisher_forward(mnet, hnet, task_id, params, X, data, batch_ids,
                    custom_forward, pass_ids):
    """Forward input ``X`` through the network(s) as described for argument
    ``custom_forward`` of :func:`compute_fisher`."""
    if hnet is None:
        if custom_forward is None:
            return mnet.forward(X, weights=params)
        if pass_ids:
            return custom_forward(mnet, params, X, data, batch_ids)
        return custom_forward(mnet, params, X)

    if custom_forward is None:
        weights = hnet.forward(task_id, theta=params)
        return mnet.forward(X, weights=weights)
    if pass_ids:
        return custom_forward(mnet, hnet, task_id, params, X, data,
                              batch_ids)
    return custom_forward(mnet, hnet, task_id, params, X)


def _fisher_nll(Y, T, data, device, allowed_outputs, empirical_fisher,
                custom_nll, pass_ids, batch_ids, regression, regression_lvar,
                time_series, target_manipulator):
    """Compute the negative log-likelihood of a single sample (helper of
    :func:`compute_fisher`)."""
    if custom_nll is not None:
        if pass_ids:
            return custom_nll(Y, T, data, allowed_outputs, empirical_fisher,
                              batch_ids)
        return custom_nll(Y, T, data, allowed_outputs, empirical_fisher)

    if regression:
        # Note, if regression, we don't have to modify the targets. Thus,
        # through "allowed_outputs" Y has been brought into the same shape
        # as T.
        if not empirical_fisher:
            raise NotImplementedError('Only empirical Fisher is ' +
                                      'implemented so far!')
        # The term that doesn't vanish in the gradient of the Gaussian log
        # probability is the (variance-scaled) squared L2 norm between Y and
        # T.
        return 0.5 / regression_lvar * (Y - T).pow(2).sum()

    # Classification. The output of the main network is asserted to be
    # linear (see `compute_fisher`), such that we can compute the log
    # probabilities by applying the log-softmax to these outputs.
    assert data.classification and len(data.out_shape) == 1
    if allowed_outputs is not None:
        assert target_manipulator is not None or \
            len(allowed_outputs) == data.num_classes
        assert Y.shape[2 if time_series else 1] == len(allowed_outputs)

    # Targets might be labels or one-hot encodings.
    if data.is_one_hot:
        if time_series:
            assert len(T.shape) == 3 and T.shape[2] == Y.shape[2]
            T = torch.argmax(T, 2)
        else:
            # Note, this function processes always one sample at a time
            # (batchsize=1), so `T` contains a single number.
            T = torch.argmax(T)

    # Important, distinguish between empirical and normal fisher!
    if not empirical_fisher:
        raise NotImplementedError('Only empirical Fisher is ' +
                                  'implemented so far!')

    if not time_series:
        # For classification, only the loss associated with the target unit
        # is taken into consideration. `nll_loss` needs an integer target of
        # shape [N] = [1] (labels might be stored as floats).
        target = torch.as_tensor(T, device=device).reshape(1).long()
        return F.nll_loss(F.log_softmax(Y, dim=1), target)

    ll = F.log_softmax(Y, dim=2) # log likelihood for all labels
    # We need to swap dimensions from [S, N, F] to [S, F, N].
    # See documentation of method `nll_loss`.
    ll = ll.permute(0, 2, 1)
    nll = F.nll_loss(ll, T, reduction='none')
    # Mean across batch dimension, but sum across time-series dimension.
    assert len(nll.shape) == 2
    return nll.mean(dim=1).sum()


def _register_ewc_buffers(net, task_id, params, fisher, online, gamma):
    """Store the current weights and the Fisher estimates as buffers of
    ``net`` (helper of :func:`compute_fisher`).

    In the online case, the previously stored Fisher (decayed by ``gamma``) is
    added to ``fisher`` before it overwrites the old buffer.
    """
    for i, p in enumerate(params):
        buff_w_name, buff_f_name = _ewc_buffer_names(task_id, i, online)

        # We use registered buffers rather than class members to ensure that
        # these variables appear in the state_dict and are thus written into
        # checkpoints.
        net.register_buffer(buff_w_name, p.detach().clone())

        # In the "online" case, the old fisher estimate buffer will be
        # overwritten.
        if online and task_id > 0:
            prev_fisher_est = getattr(net, buff_f_name)

            # Decay of previous fisher.
            fisher[i] += gamma * prev_fisher_est

        net.register_buffer(buff_f_name, fisher[i].detach().clone())
def ewc_regularizer(task_id, params, mnet, hnet=None, online=False, gamma=1.):
    """Compute the EWC regularizer, that can be added to the remaining loss.

    Note, the hyperparameter that trades off the regularization strength is
    not multiplied into the returned value; the caller has to apply it.

    This loss assumes an appropriate use of the method "compute_fisher":
    the buffers of all previous tasks must already have been registered.
    The Fisher estimate of the current task is only computed (via
    "compute_fisher") after this regularizer has been used for that task.

    If `online` is False, this method implements the loss proposed in eq. (3)
    in [EWC2017]_, except for the missing hyperparameter `lambda`.

    The online EWC implementation follows eq. (8) from [OnEWC2018]_ (note,
    that lambda does not appear in this equation, but it was used in their
    experiments).

    The stored buffers are only read, never modified, so the regularizer can
    be evaluated any number of times (e.g., once per training iteration).

    .. [EWC2017] https://arxiv.org/abs/1612.00796
    .. [OnEWC2018] https://arxiv.org/abs/1805.06370

    Args:
        (....): See docstring of method :func:`compute_fisher`.

    Returns:
        EWC regularizer.
    """
    assert task_id > 0

    net = mnet if hnet is None else hnet

    ewc_reg = 0

    num_prev_tasks = 1 if online else task_id
    for t in range(num_prev_tasks):
        for i, p in enumerate(params):
            buff_w_name, buff_f_name = _ewc_buffer_names(t, i, online)

            prev_weights = getattr(net, buff_w_name)
            fisher_est = getattr(net, buff_f_name)
            # Note, since we haven't called "compute_fisher" yet, the
            # forgetting scalar has not been multiplied yet. This must be an
            # out-of-place operation: an in-place `*=` would decay the
            # registered buffer itself on every call.
            if online:
                fisher_est = gamma * fisher_est

            ewc_reg += (fisher_est * (p - prev_weights).pow(2)).sum()

    # Note, the loss proposed in the original paper is not normalized by the
    # number of tasks (hence, no division by `num_prev_tasks`).
    return ewc_reg / 2.
def _ewc_buffer_names(task_id, param_id, online): """The names of the buffers used to store EWC variables. Args: task_id: ID of task (only used of `online` is False). param_id: Identifier of parameter tensor. online: Whether the online EWC algorithm is used. Returns: (tuple): Tuple containing: - **weight_buffer_name** - **fisher_estimate_buffer_name** """ task_ident = '' if online else '_task_%d' % task_id weight_name = 'ewc_prev{}_weights_{}'.format(task_ident, param_id) fisher_name = 'ewc_fisher_estimate{}_weights_{}'.format(task_ident, param_id) return weight_name, fisher_name
def context_mod_forward(mod_weights=None):
    """Build a ``custom_forward`` handle for function :func:`compute_fisher`.

    See argument ``custom_forward`` of function :func:`compute_fisher` for
    details. The returned handle forwards inputs through a context-modulated
    main network, whose ``forward`` method is assumed to accept a ``weights``
    dictionary with keys ``'mod_weights'`` and ``'internal_weights'`` (as
    :meth:`mnets.mlp.MLP.forward` does).

    Args:
        mod_weights (optional): If provided, :func:`compute_fisher` is assumed
            to be called with ``hnet`` set to ``None``, and the returned handle
            uses this context-modulation pattern as a fixed value. If left
            unspecified, :func:`compute_fisher` is assumed to receive an
            ``hnet`` that produces only the parameters of all context-mod
            layers; the handle then obtains them via ``hnet.forward(task_id)``.

    Returns:
        A function handle.
    """
    def _run_mnet(mnet, cm_weights, params, X):
        # Combine the context-mod weights with the internal main-net weights.
        return mnet.forward(X, weights={'mod_weights': cm_weights,
                                        'internal_weights': params})

    if mod_weights is not None:
        def mnet_only_forward(mnet, params, X):
            return _run_mnet(mnet, mod_weights, params, X)

        return mnet_only_forward

    def hnet_forward(mnet, hnet, task_id, params, X):
        return _run_mnet(mnet, hnet.forward(task_id), params, X)

    return hnet_forward
if __name__ == '__main__':
    # This module only provides functions; executing it directly does nothing.
    pass