Source code for hypnettorch.utils.context_mod_layer

#!/usr/bin/env python3
# Copyright 2019 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title          :utils/context_mod_layer.py
# @author         :ch
# @contact        :henningc@ethz.ch
# @created        :10/18/2019
# @version        :1.0
# @python_version :3.6.8
"""
Context-modulation layer
------------------------

This module implements a special gain-modulation layer that can modulate
neural computation based on an external context.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from warnings import warn

class ContextModLayer(nn.Module):
    r"""Implementation of a layer that can apply context-dependent modulation
    on the level of neuronal computation.

    The layer consists of two parameter vectors: gains :math:`\mathbf{g}` and
    shifts :math:`\mathbf{s}`, whereas gains represent a multiplicative
    modulation of input activations and shifts an additive modulation,
    respectively.

    Note, the weight vectors :math:`\mathbf{g}` and :math:`\mathbf{s}` might
    also be passed to the :meth:`forward` method, where one may pass a
    separate set of parameters for each sample in the input batch.

    Example:
        Assume that a :class:`ContextModLayer` is applied between a linear
        (fully-connected) layer
        :math:`\mathbf{y} \equiv W \mathbf{x} + \mathbf{b}` with input
        :math:`\mathbf{x}` and a nonlinear activation function
        :math:`z \equiv \sigma(y)`.

        The layer-computation in such a case will become

        .. math::

            \sigma \big( (W \mathbf{x} + \mathbf{b}) \odot \mathbf{g} +
            \mathbf{s} \big)

    Args:
        num_features (int or tuple): Number of units in the layer (size of
            parameter vectors :math:`\mathbf{g}` and :math:`\mathbf{s}`).

            In case a ``tuple`` of integers is provided, the gain
            :math:`\mathbf{g}` and shift :math:`\mathbf{s}` parameters will
            become multidimensional tensors with the shape being prescribed by
            ``num_features``. Please note the `broadcasting rules`_ as
            :math:`\mathbf{g}` and :math:`\mathbf{s}` are simply multiplied or
            added to the input.

            Example:
                Consider the output of a convolutional layer with output shape
                ``[B,C,W,H]``. In case there should be a scalar gain and shift
                per feature map, ``num_features`` could be ``[C,1,1]`` or
                ``[1,C,1,1]``. With ``num_features=[C,1,1]`` one might also
                pass parameters of shape ``[B,C,1,1]`` to the :meth:`forward`
                method to apply separate shifts and gains per sample in the
                batch. Per-sample parameters are only recognized if they have
                shape ``[B, *num_features]`` and the input has
                ``len(num_features) + 1`` dimensions, hence this does not work
                with ``num_features=[1,C,1,1]``.

                Alternatively, one might want to provide shift and gain per
                output unit, i.e., ``num_features`` should be ``[C,W,H]``.
                Note, that due to weight sharing, all output activities within
                a feature map are computed using the same weights, which is
                why it is common practice to share shifts and gains within a
                feature map (e.g., in Spatial Batch-Normalization).
        no_weights (bool): If ``True``, the layer will have no trainable
            weights (:math:`\mathbf{g}` and :math:`\mathbf{s}`). Hence, weights
            are expected to be passed to the :meth:`forward` method.
        no_gains (bool): If ``True``, no gain parameters :math:`\mathbf{g}`
            will be modulating the input activity.

            .. note::
                Arguments ``no_gains`` and ``no_shifts`` may not be activated
                simultaneously!
        no_shifts (bool): If ``True``, no shift parameters :math:`\mathbf{s}`
            will be modulating the input activity.
        apply_gain_offset (bool, optional): If activated, this option will
            apply a constant offset of 1 to all gains, i.e., the computation
            becomes

            .. math::

                \sigma \big( (W \mathbf{x} + \mathbf{b}) \odot
                (1 + \mathbf{g}) + \mathbf{s} \big)

            When could that be useful? In case the gains and shifts are
            generated by the same hypernetwork, a meaningful initialization
            might be difficult to achieve (e.g., such that gains are close to 1
            and shifts are close to 0 at the beginning). Therefore, one might
            initialize the hypernetwork such that all outputs are close to zero
            at the beginning and the constant offset ensures that meaningful
            gains are applied.
        apply_gain_softplus (bool, optional): If activated, this option will
            enforce positive gain modulation by sending the gain weights
            :math:`\mathbf{g}` through a softplus function (scaled by
            :math:`s`, see ``softplus_scale``).

            .. math::

                \mathbf{g} = \frac{1}{s} \log(1+\exp(\mathbf{g} \cdot s))
        softplus_scale (float): If option ``apply_gain_softplus`` is ``True``,
            then this will determine the scale of the softplus function.

    .. _broadcasting rules:
        https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics
    """
    def __init__(self, num_features, no_weights=False, no_gains=False,
                 no_shifts=False, apply_gain_offset=False,
                 apply_gain_softplus=False, softplus_scale=1.):
        super(ContextModLayer, self).__init__()

        assert(isinstance(num_features, (int, list, tuple)))
        if not isinstance(num_features, int):
            for nf in num_features:
                assert(isinstance(nf, int))
        else:
            num_features = [num_features]

        assert(not no_gains or not no_shifts)

        # Keep a private copy so that later mutation of the caller's list
        # cannot silently change the expected parameter shape.
        self._num_features = list(num_features)
        self._no_weights = no_weights
        self._no_gains = no_gains
        self._no_shifts = no_shifts
        self._apply_gain_offset = apply_gain_offset
        self._apply_gain_softplus = apply_gain_softplus
        self._sps = softplus_scale

        if apply_gain_offset and apply_gain_softplus:
            raise ValueError('Options "apply_gain_offset" and ' +
                             '"apply_gain_softplus" are not compatible.')

        self._weights = None
        num_params = 1 if no_gains or no_shifts else 2
        # One independent list per parameter (a `[shape] * n` construction
        # would alias the very same list object n times).
        self._param_shapes = [list(num_features) for _ in range(num_params)]
        self._param_shapes_meta = ([] if no_gains else ['gain']) + \
            ([] if no_shifts else ['shift'])

        # Stored as a buffer (rather than a Python int) so that it is part of
        # the module's state dict.
        self.register_buffer('_num_ckpts', torch.tensor(0, dtype=torch.long))

        if not no_weights:
            self._weights = nn.ParameterList()

            if not no_gains:
                self.register_parameter('gain', nn.Parameter(
                    torch.Tensor(*num_features), requires_grad=True))
                self._weights.append(self.gain)

                # With a gain offset, a raw gain of 0 corresponds to an
                # effective gain of 1.
                if apply_gain_offset:
                    nn.init.zeros_(self.gain)
                else:
                    nn.init.ones_(self.gain)
            else:
                self.register_parameter('gain', None)

            if not no_shifts:
                self.register_parameter('shift', nn.Parameter(
                    torch.Tensor(*num_features), requires_grad=True))
                self._weights.append(self.shift)

                nn.init.zeros_(self.shift)
            else:
                self.register_parameter('shift', None)

    @property
    def weights(self):
        """A list of all internal weights of this layer.

        If all weights are assumed to be generated externally, then this
        attribute will be ``None``.

        :type: torch.nn.ParameterList or None
        """
        return self._weights

    @property
    def param_shapes(self):
        """A list of list of integers. Each list represents the shape of a
        parameter tensor.

        Note, this attribute is independent of the attribute
        :attr:`weights`, it always comprises the shapes of all weight tensors
        as if the network would be stand-alone (i.e., no weights being passed
        to the :meth:`forward` method).

        .. note::
            The weights passed to the :meth:`forward` method might deviate
            from these shapes, as we allow passing a distinct set of
            parameters per sample in the input batch.

        :type: list
        """
        return self._param_shapes

    @property
    def param_shapes_meta(self):
        r"""List of strings. Each entry represents the meaning of the
        corresponding entry in :attr:`param_shapes`. The following keywords
        are possible:

        - ``'gain'``: The corresponding shape in :attr:`param_shapes`
          denotes the gain :math:`\mathbf{g}` parameter.
        - ``'shift'``: The corresponding shape in :attr:`param_shapes`
          denotes the shift :math:`\mathbf{s}` parameter.

        :type: list
        """
        return self._param_shapes_meta

    @property
    def num_ckpts(self):
        """The number of existing weight checkpoints (i.e., how often the
        method :meth:`checkpoint_weights` was called).

        Note, the value is returned as the underlying 0-dim integer tensor
        buffer (use ``int(layer.num_ckpts)`` to obtain a Python integer).

        :type: torch.Tensor
        """
        return self._num_ckpts

    @property
    def gain_offset_applied(self):
        r"""Whether constructor argument ``apply_gain_offset`` was activated.

        Thus, whether an offset for the gain :math:`\mathbf{g}` is applied.

        :type: bool
        """
        return self._apply_gain_offset

    @property
    def gain_softplus_applied(self):
        r"""Whether constructor argument ``apply_gain_softplus`` was activated.

        Thus, whether a softplus function for the gain :math:`\mathbf{g}` is
        applied.

        :type: bool
        """
        return self._apply_gain_softplus

    @property
    def has_gains(self):
        r"""Is ``True`` if ``no_gains`` was not set in the constructor.

        Thus, whether gains :math:`\mathbf{g}` are part of the computation of
        this layer.

        :type: bool
        """
        return not self._no_gains

    @property
    def has_shifts(self):
        r"""Is ``True`` if ``no_shifts`` was not set in the constructor.

        Thus, whether shifts :math:`\mathbf{s}` are part of the computation
        of this layer.

        :type: bool
        """
        return not self._no_shifts

    def forward(self, x, weights=None, ckpt_id=None, bs_dim=0):
        r"""Apply context-dependent gain modulation.

        Computes :math:`\mathbf{x} \odot \mathbf{g} + \mathbf{s}`, where
        :math:`\mathbf{x}` denotes the input activity ``x``. Gains are first
        transformed via :meth:`preprocess_gain`.

        Args:
            x: The input activity.
            weights: Weights that should be used instead of the internally
                maintained ones (determined by attribute :attr:`weights`).
                Note, if ``no_weights`` was ``True`` in the constructor, then
                this parameter is mandatory.

                Usually, the shape of the passed weights should follow the
                attribute :attr:`param_shapes`, which is a list of shapes
                ``[[num_features], [num_features]]`` (at least for linear
                layers, see docstring of argument ``num_features`` in the
                constructor for more details). However, one may also specify a
                separate set of context-mod parameters per input sample.
                Assume ``x`` has shape ``[num_samples, num_features]``. Then
                ``weights`` may have the shape
                ``[[num_samples, num_features], [num_samples, num_features]]``.
            ckpt_id (int): This argument can be set in case a checkpointed set
                of weights should be used to compute the forward pass (see
                method :meth:`checkpoint_weights`).

                .. note::
                    This argument is ignored if ``weights`` is not ``None``.
            bs_dim (int): Batch size dimension in input tensor ``x``.

        Returns:
            The modulated input activity.

        Raises:
            ValueError: If the layer has no internal weights and ``weights``
                is ``None``.
        """
        if self._no_weights and weights is None:
            raise ValueError('Layer was generated without weights. ' +
                             'Hence, "weights" option may not be None.')

        if weights is not None and ckpt_id is not None:
            warn('Context-mod layer received weights as well as the request ' +
                 'to load checkpointed weights. The request to load ' +
                 'checkpointed weights will be ignored.')

        # FIXME I haven't thoroughly checked whether broadcasting works
        # correctly if `bs_dim != 0`.
        batch_size = x.shape[bs_dim]

        if weights is None:
            gain, shift = self.get_weights(ckpt_id=ckpt_id)
            if self._no_gains:
                weights = [shift]
            elif self._no_shifts:
                weights = [gain]
            else:
                weights = [gain, shift]
        else:
            assert(len(weights) in [1, 2])
            nfl = len(self._num_features)
            nb = len(x.shape)
            for p in weights:
                # Note, the user might add the batch dimension when providing
                # gains and shifts, such that there are separate gain and
                # shift parameters per sample in the batch.
                assert(len(p.shape) in [nfl, nb])
                if len(p.shape) == nfl:
                    assert(np.all(np.equal(p.shape, self._num_features)))
                else:
                    # One set of parameters per sample in the batch.
                    assert(p.shape[0] == batch_size and
                           np.all(np.equal(p.shape[1:], self._num_features)))

        gain = None
        shift = None
        if self._no_gains:
            assert(len(weights) == 1)
            shift = weights[0]
        elif self._no_shifts:
            assert(len(weights) == 1)
            gain = weights[0]
        else:
            assert(len(weights) == 2)
            gain = weights[0]
            shift = weights[1]

        if gain is not None:
            x = x.mul(self.preprocess_gain(gain))
        if shift is not None:
            x = x.add(shift)

        return x

    def preprocess_gain(self, gain):
        r"""Obtains gains :math:`\mathbf{g}` used for modulation.

        Depending on the user configuration, gains might be preprocessed
        before being applied for context-modulation (e.g., see attributes
        :attr:`gain_offset_applied` or :attr:`gain_softplus_applied`). This
        method transforms raw gains such that they can be applied to the
        network activation.

        Note:
            This method is called by the :meth:`forward` to transform given
            gains.

        Args:
            gain (torch.Tensor): A gain tensor.

        Returns:
            (torch.Tensor): The transformed gains.
        """
        if self._apply_gain_softplus:
            gain = 1. / self._sps * F.softplus(gain * self._sps)
        elif self._apply_gain_offset:
            gain = gain + 1.

        return gain

    def checkpoint_weights(self, device=None, no_reinit=False):
        """Checkpoint and reinit the current weights.

        Buffers for a new checkpoint will be registered and the current
        weights will be copied into them. Additionally, the current weights
        will be reinitialized (gains to 1, or to 0 if a gain offset is
        applied, and shifts to 0).

        Calling this function will also increment the attribute
        :attr:`num_ckpts`.

        Note:
            This method uses the method
            :meth:`torch.nn.Module.register_buffer` rather than the method
            :meth:`torch.nn.Module.register_parameter` to create checkpoints.
            The reason is, that we don't want the checkpoints to appear as
            trainable weights (when calling :meth:`torch.nn.Module.parameters`).
            However, that means that training on checkpointed weights cannot
            be continued unless they are copied back into an actual
            :class:`torch.nn.Parameter` object.

        Args:
            device (optional): The device on which the checkpoint is stored.
                If not provided, the newly created checkpoint will be placed
                on the device of the current weights.
            no_reinit (bool): If ``True``, the actual :attr:`weights` will not
                be reinitialized.
        """
        assert(not self._no_weights)

        if device is None:
            if self.gain is not None:
                device = self.gain.device
            else:
                device = self.shift.device

        gname, sname = self._weight_names(int(self._num_ckpts))
        self._num_ckpts += 1

        if not self._no_gains:
            # Detached copy, explicitly moved to the requested device.
            self.register_buffer(gname,
                                 self.gain.detach().clone().to(device))

            if not no_reinit:
                if self._apply_gain_offset:
                    nn.init.zeros_(self.gain)
                else:
                    nn.init.ones_(self.gain)
        else:
            self.register_buffer(gname, None)

        if not self._no_shifts:
            self.register_buffer(sname,
                                 self.shift.detach().clone().to(device))

            if not no_reinit:
                nn.init.zeros_(self.shift)
        else:
            # Register a `None` placeholder under the *shift* name so that
            # `get_weights` can look it up (and the gain checkpoint registered
            # above is not overwritten).
            self.register_buffer(sname, None)

    def get_weights(self, ckpt_id=None):
        """Get the current (or a set of checkpointed) weights of this
        context-mod layer.

        Args:
            ckpt_id (optional): ID of checkpoint. If not provided, the current
                set of weights is returned.
                If :code:`ckpt_id == self.num_ckpts`, then this method also
                returns the current weights, as the checkpoint has not been
                created yet.

        Returns:
            (tuple): Tuple containing:

            - **gain**: Is ``None`` if layer has no gains.
            - **shift**: Is ``None`` if layer has no shifts.
        """
        if ckpt_id is None or ckpt_id == self.num_ckpts:
            return self.gain, self.shift

        assert(ckpt_id >= 0 and ckpt_id < self.num_ckpts)

        gname, sname = self._weight_names(ckpt_id)
        gain = getattr(self, gname)
        shift = getattr(self, sname)

        return gain, shift

    def _weight_names(self, ckpt_id):
        """Get the buffer names for checkpointed gain and shift weights
        depending on the ``ckpt_id``, i.e., the ID of the checkpoint.

        Args:
            ckpt_id: ID of weight checkpoint.

        Returns:
            (tuple): Tuple containing:

            - **gain_name**
            - **shift_name**
        """
        gain_name = 'gain_ckpt_%d' % ckpt_id
        shift_name = 'shift_ckpt_%d' % ckpt_id

        return gain_name, shift_name

    def normal_init(self, std=1.):
        """Reinitialize internal weights using a normal distribution.

        Gains are centered at 1 (or at 0 if a gain offset is applied) and
        shifts at 0.

        Args:
            std (float): Standard deviation of init.

        Raises:
            ValueError: If the layer has no internally maintained weights.
        """
        if self._no_weights:
            raise ValueError('Method is not applicable to layers without ' +
                             'internally maintained weights.')

        if not self._no_gains:
            if self._apply_gain_offset:
                nn.init.normal_(self.gain, std=std)
            else:
                nn.init.normal_(self.gain, mean=1., std=std)

        if not self._no_shifts:
            nn.init.normal_(self.shift, std=std)

    def uniform_init(self, width=1.):
        """Reinitialize internal weights using a uniform distribution.

        Args:
            width (float): The range of the uniform init will be determined
                as ``[mean-width, mean+width]``, where ``mean`` is 0 for
                shifts and 1 for gains (0 if a gain offset is applied).

        Raises:
            ValueError: If the layer has no internally maintained weights.
        """
        if self._no_weights:
            raise ValueError('Method is not applicable to layers without ' +
                             'internally maintained weights.')

        if not self._no_gains:
            if self._apply_gain_offset:
                nn.init.uniform_(self.gain, a=-width, b=width)
            else:
                nn.init.uniform_(self.gain, a=1.-width, b=1.+width)

        if not self._no_shifts:
            nn.init.uniform_(self.shift, a=-width, b=width)

    def sparse_init(self, sparsity=.8):
        """Reinitialize internal weights sparsely.

        Gains will be initialized such that ``sparsity * 100`` percent of them
        will be 0, the remaining ones will be 1 (in terms of effective gain,
        i.e., raw values -1 and 0 if a gain offset is applied). Shifts are
        initialized to 0.

        Args:
            sparsity (float): A number between 0 and 1 determining the
                sparsity level of gains.

        Raises:
            ValueError: If the layer has no internally maintained weights.
        """
        if self._no_weights:
            raise ValueError('Method is not applicable to layers without ' +
                             'internally maintained weights.')

        assert 0 <= sparsity <= 1

        if not self._no_gains:
            num_zeros = int(self.gain.numel() * sparsity)

            # Boolean mask with exactly `num_zeros` `True` entries at random
            # positions (drawn from numpy's global RNG).
            mask = np.zeros(self.gain.numel(), dtype=bool)
            mask[:num_zeros] = True
            np.random.shuffle(mask)
            mask = torch.from_numpy(mask.reshape(self.gain.shape)).to(
                self.gain.device)

            if self._apply_gain_offset:
                nn.init.zeros_(self.gain)
                self.gain.data[mask] = -1.
            else:
                nn.init.ones_(self.gain)
                self.gain.data[mask] = 0.

        if not self._no_shifts:
            nn.init.zeros_(self.shift)
if __name__ == '__main__':
    # This module only provides a library class; executing it as a script
    # intentionally does nothing.
    pass