Source code for hypnettorch.mnets.bio_conv_net

#!/usr/bin/env python3
# Copyright 2019 Christian Henning
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# @title          :mnets/
# @author         :ch
# @contact
# @created        :10/30/2019
# @version        :1.0
# @python_version :3.6.8
A bio-plausible convolutional network for CIFAR

The module :mod:`mnets.bio_conv_net` implements a simple biologically-plausible
network with convolutional and fully-connected layers. The bio-plausibility
arises through the usage of conv-layers without weight sharing, i.e., layers
from class :class:`utils.local_conv2d_layer.LocalConv2dLayer`. The network
specification is taken from the following paper:

    Bartunov et al., "Assessing the Scalability of Biologically-Motivated Deep
    Learning Algorithms and Architectures", NeurIPS 2018.

in which this kind of network has been termed "locally-connected network".
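To make "conv-layers without weight sharing" concrete, here is a minimal pure-Python sketch of a locally-connected layer. The helpers `out_size` and `local_conv2d` and the nested-list tensor layout are hypothetical illustrations. They are not the API or memory layout of `LocalConv2dLayer`. Every output location `(i, j)` of every output channel reads its own filter. If all locations are given the same filter, the layer reduces to an ordinary (cross-correlation) convolution.

```python
def out_size(n, k, s, p):
    """Output height/width of a conv-like layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * p - k) // s + 1


def local_conv2d(x, filters, bias, k, s=1, p=0):
    """Locally-connected 2D layer (no weight sharing).

    Hypothetical layout, chosen for readability only:
      x:       [C_in][H][W]
      filters: [C_out][H_out][W_out][C_in][k][k]  (one filter per location)
      bias:    [C_out][H_out][W_out]
    """
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    h_out, w_out = out_size(h, k, s, p), out_size(w, k, s, p)

    def pixel(c, r, col):
        # Zero padding: positions outside the input contribute 0.
        r, col = r - p, col - p
        if 0 <= r < h and 0 <= col < w:
            return x[c][r][col]
        return 0.0

    out = []
    for o in range(len(filters)):
        fmap = []
        for i in range(h_out):
            row = []
            for j in range(w_out):
                f = filters[o][i][j]  # filter owned by location (i, j) only
                acc = bias[o][i][j]
                for c in range(c_in):
                    for di in range(k):
                        for dj in range(k):
                            acc += f[c][di][dj] * pixel(c, i * s + di, j * s + dj)
                row.append(acc)
            fmap.append(row)
        out.append(fmap)
    return out


x = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]  # 1 x 3 x 3 input
ones = [[[1.0, 1.0], [1.0, 1.0]]]    # a 1 x 2 x 2 filter
zeros = [[[0.0, 0.0], [0.0, 0.0]]]
shared = [[[ones, ones], [ones, ones]]]     # same filter everywhere -> conv
untied = [[[ones, zeros], [zeros, zeros]]]  # only location (0, 0) sees input
bias = [[[0.0, 0.0], [0.0, 0.0]]]

print(local_conv2d(x, shared, bias, k=2))  # [[[12.0, 16.0], [24.0, 28.0]]]
print(local_conv2d(x, untied, bias, k=2))  # [[[12.0, 0.0], [0.0, 0.0]]]

# Cost of dropping weight sharing in the first CIFAR layer (5 x 5, 64, 2, 0),
# assuming one full filter per output location as in the sketch above.
h1 = out_size(32, 5, 2, 0)       # 14 output rows/columns
shared_weights = 64 * 3 * 5 * 5  # one 3 x 5 x 5 filter per output channel
print(shared_weights, shared_weights * h1 * h1)  # 4800 940800
```

Dropping weight sharing multiplies the filter weights of that layer by the number of output locations (here 14 x 14 = 196).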

In particular, we consider the network architecture specified in table 3 on page
13 for the CIFAR dataset.

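The CIFAR architecture from that table is hard-coded in the constructor of `BioConvNet` (kernel sizes 5, 5, 3; channels 64, 128, 256; strides 2, 2, 1; paddings 0, 0, 1). The feature-map sizes can be checked with the usual output-size formula. This sketch assumes `LocalConv2dLayer` follows that formula, which is consistent with the constructor's `assert(n_in == 6400)`. The helpers `out_size` and `feature_shapes` are hypothetical and not part of hypnettorch.

```python
def out_size(n, k, s, p):
    """Output height/width of a conv-like layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * p - k) // s + 1


def feature_shapes(h=32, w=32, c=3):
    """Return [C, H, W] after each bio-conv layer of the CIFAR architecture."""
    layers = [  # (kernel K, channels C, stride S, padding P)
        (5, 64, 2, 0),
        (5, 128, 2, 0),
        (3, 256, 1, 1),
    ]
    shapes = []
    for k, c_out, s, p in layers:
        h, w, c = out_size(h, k, s, p), out_size(w, k, s, p), c_out
        shapes.append([c, h, w])
    return shapes


shapes = feature_shapes()
print(shapes)  # [[64, 14, 14], [128, 5, 5], [256, 5, 5]]
C, H, W = shapes[-1]
print(C * H * W)  # 6400 inputs to the first FC layer (1024 units)
```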

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from hypnettorch.mnets.classifier_interface import Classifier
from hypnettorch.mnets.mnet_interface import MainNetInterface
from hypnettorch.utils.context_mod_layer import ContextModLayer
from hypnettorch.utils.local_conv2d_layer import LocalConv2dLayer
from hypnettorch.utils.torch_utils import init_params

class BioConvNet(Classifier):
    r"""Implementation of a locally-connected network for CIFAR.

    The network consists of 3 bio-plausible convolutional layers (using class
    :class:`utils.local_conv2d_layer.LocalConv2dLayer`) followed by two
    fully-connected layers.

    Assume conv layers are specified by the tuple ``(K x K, C, S, P)``, where
    ``K`` denotes the kernel size, ``C`` the number of channels, ``S`` the
    stride and ``P`` the padding. The network is defined as follows:

        - Bio-conv layer (5 x 5, 64, 2, 0)
        - Bio-conv layer (5 x 5, 128, 2, 0)
        - Bio-conv layer (3 x 3, 256, 1, 1)
        - FC layer with 1024 outputs
        - FC layer with ``num_classes`` (default: 10) outputs

    Note, the padding for the first two convolutional layers was not specified
    in the paper, so we assume it to be zero.

    The **network output will be linear**, so we do not apply the softmax
    inside the :meth:`forward` method.

    Note, the paper states that ``tanh`` was used as non-linearity in all
    networks. Therefore, we use this non-linearity too.

    Args:
        in_shape: The shape of an input sample.

            .. note::
                We assume the Tensorflow format, where the last entry
                denotes the number of channels.
        num_classes: The number of output neurons.
        no_weights (bool): If set to ``True``, no trainable parameters will be
            constructed, i.e., weights are assumed to be produced ad-hoc
            by a hypernetwork and passed to the :meth:`forward` method.
        init_weights (optional): This option is for convenience. It expects a
            list of parameter values that are used to initialize the network
            weights. As such, it provides a convenient way of initializing a
            network with a weight draw produced by the hypernetwork.

            Note, only internally maintained weights (see
            :attr:`mnets.mnet_interface.MainNetInterface.weights`) are
            affected by this argument.
        use_context_mod (bool): Add context-dependent modulation layers
            :class:`utils.context_mod_layer.ContextModLayer` after the linear
            computation of each layer.
        context_mod_inputs (bool): Whether context-dependent modulation should
            also be applied to network inputs directly. I.e., assume
            :math:`\mathbf{x}` is the input to the network. Then the first
            network operation would be to modify the input via
            :math:`\mathbf{x} \cdot \mathbf{g} + \mathbf{s}` using
            context-dependent gain and shift parameters.

            Note:
                Argument applies only if ``use_context_mod`` is ``True``.
        no_last_layer_context_mod (bool): If ``True``, context-dependent
            modulation will not be applied to the output layer.

            Note:
                Argument applies only if ``use_context_mod`` is ``True``.
        context_mod_no_weights (bool): The weights of the context-mod layers
            (:class:`utils.context_mod_layer.ContextModLayer`) are treated
            independently of the option ``no_weights``.
            This argument can be used to decide whether the context-mod
            parameters (gains and shifts) are maintained internally or
            externally.

            Note:
                Check out argument ``weights`` of the :meth:`forward` method
                on how to correctly pass weights to the network that are
                externally maintained.
        context_mod_post_activation (bool): Apply context-mod layers after the
            activation function (``tanh``) in hidden layers rather than
            before, which is the default behavior.

            Note:
                This option only applies if ``use_context_mod`` is ``True``.

            Note:
                This option does not affect argument ``context_mod_inputs``.

            Note:
                There is no non-linearity applied to the output layer, such
                that this argument has no effect there.
        context_mod_gain_offset (bool): Activates option ``apply_gain_offset``
            of class :class:`utils.context_mod_layer.ContextModLayer` for all
            context-mod layers that will be instantiated.
        context_mod_gain_softplus (bool): Activates option
            ``apply_gain_softplus`` of class
            :class:`utils.context_mod_layer.ContextModLayer` for all
            context-mod layers that will be instantiated.
        context_mod_apply_pixel_wise (bool): If ``False``, the
            context-dependent modulation applies a scalar gain and shift to
            each feature map in the output of a convolutional layer. When
            activating this option, the gain and shift will be a per-pixel
            parameter in all feature maps.

            To be more precise, consider the output of a convolutional layer
            of shape ``[C,H,W]``. If ``False``, there will be ``C`` gain and
            shift parameters for such a layer. Upon activating this option,
            the number of gain and shift parameters for such a layer will
            increase to ``C x H x W``.
    """
    def __init__(self, in_shape=(32, 32, 3), num_classes=10, no_weights=False,
                 init_weights=None, use_context_mod=False,
                 context_mod_inputs=False, no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False, context_mod_gain_softplus=False,
                 context_mod_apply_pixel_wise=False):
        super(BioConvNet, self).__init__(num_classes, True)

        assert(len(in_shape) == 3)
        # FIXME This assertion is not mandatory but a sanity check that the user
        # uses the Tensorflow layout.
        assert(in_shape[2] in [1, 3])
        assert(init_weights is None or \
               (not no_weights or not context_mod_no_weights))

        self._in_shape = in_shape
        self._no_weights = no_weights
        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._context_mod_gain_softplus = context_mod_gain_softplus
        self._context_mod_apply_pixel_wise = context_mod_apply_pixel_wise

        self._has_bias = True
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        # Shapes of output activities for context-modulation, if used.
        cm_shapes = [] # Output shape of all layers.
        if context_mod_inputs:
            cm_shapes.append([in_shape[2], *in_shape[:2]])

        ### Define and initialize all conv and linear layers
        ### Bio-conv layers.
        H = in_shape[0]
        W = in_shape[1]
        C_in = in_shape[2]
        C = [64, 128, 256]
        K = [5, 5, 3]
        S = [2, 2, 1]
        P = [0, 0, 1]

        self._conv_layer = []
        for i, C_out in enumerate(C):
            self._conv_layer.append(LocalConv2dLayer(C_in, C_out, H, W, K[i],
                stride=S[i], padding=P[i], no_weights=no_weights))
            H = self._conv_layer[-1].out_height
            W = self._conv_layer[-1].out_width
            cm_shapes.append([C_out, H, W])
            C_in = C_out

            self._param_shapes.extend(self._conv_layer[-1].param_shapes)
            if no_weights:
                self._hyper_shapes_learned.extend( \
                    self._conv_layer[-1].param_shapes)
            else:
                self._weights.extend(self._conv_layer[-1].weights)

                assert(len(self._conv_layer[-1].weights) == 2)
                self._layer_weight_tensors.append( \
                    self._conv_layer[-1].filters)
                self._layer_bias_vectors.append( \
                    self._conv_layer[-1].bias)

        ### Linear layers
        n_in = H * W * C_out
        assert(n_in == 6400)
        n = [1024, num_classes]
        for i, n_out in enumerate(n):
            W_shape = [n_out, n_in]
            b_shape = [n_out]

            # Note, that the last layer shape might not be used for context-
            # modulation.
            if i < (len(n)-1) or not no_last_layer_context_mod:
                cm_shapes.append([n_out])

            n_in = n_out

            self._param_shapes.extend([W_shape, b_shape])
            if no_weights:
                self._hyper_shapes_learned.extend([W_shape, b_shape])
            else:
                W = nn.Parameter(torch.Tensor(*W_shape), requires_grad=True)
                b = nn.Parameter(torch.Tensor(*b_shape), requires_grad=True)
                init_params(W, b)

                self._weights.extend([W, b])

                self._layer_weight_tensors.append(W)
                self._layer_bias_vectors.append(b)

        ### Define and initialize context mod weights.
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None
        self._context_mod_weights = nn.ParameterList() if use_context_mod \
            else None

        if use_context_mod:
            if not context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            for i, s in enumerate(cm_shapes):
                cmod_layer = ContextModLayer(s,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset,
                    apply_gain_softplus=context_mod_gain_softplus)
                self._context_mod_layers.append(cmod_layer)

                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if not context_mod_no_weights:
                    self._context_mod_weights.extend(cmod_layer.weights)

            # We always add the context mod weights/shapes at the beginning of
            # our list attributes.
            self._param_shapes = self._context_mod_shapes + self._param_shapes
            if context_mod_no_weights:
                self._hyper_shapes_learned = self._context_mod_shapes + \
                    self._hyper_shapes_learned
            else:
                tmp = self._weights
                self._weights = nn.ParameterList(self._context_mod_weights)
                for w in tmp:
                    self._weights.append(w)

        ### Apply custom init if given.
        if init_weights is not None:
            assert(len(self.weights) == len(init_weights))
            for i in range(len(init_weights)):
                assert(np.all(np.equal(list(init_weights[i].shape),
                                       list(self._weights[i].shape))))
                self._weights[i].data = init_weights[i]

        ### Print user info.
        num_weights = MainNetInterface.shapes_to_num_weights( \
            self._param_shapes)
        if use_context_mod:
            cm_num_weights = MainNetInterface.shapes_to_num_weights( \
                self._context_mod_shapes)

        print('Creating bio-plausible convnet with %d weights' % num_weights
              + (' (including %d weights associated with ' % cm_num_weights
                 + 'context modulation)' if use_context_mod else '') + '.')

        self._is_properly_setup()
    def forward(self, x, weights=None, distilled_params=None, condition=None,
                collect_activations=False):
        """Compute the output :math:`y` of this network given the input
        :math:`x`.

        Args:
            (....): See docstring of method
                :meth:`mnets.mnet_interface.MainNetInterface.forward`. We
                provide some more specific information below.
            x: Input image.

                .. note::
                    We assume the Tensorflow format, where the last entry
                    denotes the number of channels.
            weights (list or dict): If a list of parameter tensors is given and
                context modulation is used (see argument ``use_context_mod`` in
                the constructor), then these parameters are interpreted as
                context-modulation parameters if the length of ``weights``
                equals :code:`2*len(net.context_mod_layers)`. Otherwise, the
                length is expected to be equal to the length of the attribute
                :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`.

                Alternatively, a dictionary can be passed with the possible
                keywords ``internal_weights`` and ``mod_weights``. Each keyword
                is expected to map onto a list of tensors.
                The keyword ``internal_weights`` refers to all weights of this
                network except for the weights of the context-modulation layers.
                The keyword ``mod_weights``, on the other hand, refers
                specifically to the weights of the context-modulation layers.
                It is not necessary to specify both keywords.
            condition (int, optional): Will be passed as argument ``ckpt_id`` to
                the method :meth:`utils.context_mod_layer.ContextModLayer.forward`
                for all context-mod layers in this network.
            collect_activations (bool, optional): Whether to additionally
                return the activations in the network. This information can be
                used for credit assignment later on, in case an alternative to
                PyTorch's :mod:`torch.autograd` should be used.

        Returns:
            (:class:`torch.Tensor` or tuple): If ``collect_activations`` is
            ``False``, only the network output **y** is returned. Otherwise, a
            tuple containing:

            - **y**: The output of the network.
            - **layer_activation**: The activations of the network, i.e., a
              list with the outputs of all convolutional and linear layers
              (recorded before context modulation and non-linearity).
            - **last_hidden**: The output of the last hidden (fully-connected)
              layer after its context modulation (if used) and non-linearity.
        """
        if ((not self._use_context_mod and self._no_weights) or \
                (self._no_weights or self._context_mod_no_weights)) and \
                weights is None:
            raise Exception('Network was generated without weights. ' +
                            'Hence, "weights" option may not be None.')

        ############################################
        ### Extract which weights should be used ###
        ############################################
        # I.e., are we using internally maintained weights or externally given
        # ones or are we even mixing between these groups.
        # FIXME code mostly copied from MLP forward method.
        n_cm = 0 if self.context_mod_layers is None else \
            2 * len(self.context_mod_layers)

        if weights is None:
            weights = self.weights

            if self._use_context_mod:
                cm_weights = weights[:n_cm]
                int_weights = weights[n_cm:]
            else:
                int_weights = weights
        else:
            int_weights = None
            cm_weights = None

            if isinstance(weights, dict):
                assert('internal_weights' in weights.keys() or \
                       'mod_weights' in weights.keys())
                if 'internal_weights' in weights.keys():
                    int_weights = weights['internal_weights']
                if 'mod_weights' in weights.keys():
                    cm_weights = weights['mod_weights']
            else:
                if self._use_context_mod and \
                        len(weights) == n_cm:
                    cm_weights = weights
                else:
                    assert(len(weights) == len(self.param_shapes))
                    if self._use_context_mod:
                        cm_weights = weights[:n_cm]
                        int_weights = weights[n_cm:]
                    else:
                        int_weights = weights

            if self._use_context_mod and cm_weights is None:
                if self._context_mod_no_weights:
                    raise Exception('Network was generated without weights ' +
                        'for context-mod layers. Hence, they must be passed ' +
                        'via the "weights" option.')
                cm_weights = self.weights[:n_cm]
            if int_weights is None:
                if self._no_weights:
                    raise Exception('Network was generated without internal ' +
                        'weights. Hence, they must be passed via the ' +
                        '"weights" option.')
                if self._context_mod_no_weights:
                    int_weights = self.weights
                else:
                    int_weights = self.weights[n_cm:]

            # Note, context-mod weights might have different shapes, as they
            # may be parametrized on a per-sample basis.
            if self._use_context_mod:
                assert(len(cm_weights) == len(self._context_mod_shapes))
            int_shapes = self.param_shapes[n_cm:]
            assert(len(int_weights) == len(int_shapes))
            for i, s in enumerate(int_shapes):
                assert(np.all(np.equal(s, list(int_weights[i].shape))))

        ###############################################
        ### Extract weight tensors and bias vectors ###
        ###############################################
        w_weights = []
        b_weights = []
        for i, p in enumerate(int_weights):
            if self.has_bias and i % 2 == 1:
                b_weights.append(p)
            else:
                w_weights.append(p)

        #######################
        ### Parse condition ###
        #######################

        cmod_cond = condition

        ###########################
        ### Forward Computation ###
        ###########################
        cm_ind = 0
        layer_ind = 0

        # Convert from Tensorflow layout [N, H, W, C] to PyTorch [N, C, H, W].
        x = x.view(-1, *self._in_shape)
        x = x.permute(0, 3, 1, 2)
        h = x

        activations = []

        # Context-dependent modulation of inputs directly.
        if self._use_context_mod and self._context_mod_inputs:
            h = self._context_mod_layers[cm_ind].forward(h,
                weights=cm_weights[2*cm_ind:2*cm_ind+2], ckpt_id=cmod_cond)
            cm_ind += 1

        # Convolutional layers.
        for i, conv_layer in enumerate(self._conv_layer):
            h = conv_layer.forward(h, weights=[w_weights[layer_ind],
                                               b_weights[layer_ind]])
            layer_ind += 1

            if collect_activations:
                activations.append(h.clone())

            # Non-linearity (if context-dependent mod is applied post non-lin.)
            if self._context_mod_post_activation:
                h = torch.tanh(h)

            # Context-dependent modulation.
            if self._use_context_mod:
                h = self._context_mod_layers[cm_ind].forward(h,
                    weights=cm_weights[2*cm_ind:2*cm_ind+2], ckpt_id=cmod_cond)
                cm_ind += 1

            # Non-linearity (if context-dependent mod is applied pre non-lin.)
            if not self._context_mod_post_activation:
                h = torch.tanh(h)

        # Flatten feature maps.
        h = h.view(h.size(0), -1)

        # Linear layers.
        assert(len(w_weights) == layer_ind + 2)
        for _ in range(2):
            W = w_weights[layer_ind]
            if self.has_bias:
                b = b_weights[layer_ind]
            else:
                b = None

            # Linear layer.
            h = F.linear(h, W, bias=b)

            if collect_activations:
                activations.append(h.clone())

            # Do not consider the output layer.
            if layer_ind < len(w_weights) - 1:
                # Non-linearity (if context-dependent mod is applied post
                # non-lin.)
                if self._context_mod_post_activation:
                    h = torch.tanh(h)

                # Context-dependent modulation.
                if self._use_context_mod:
                    h = self._context_mod_layers[cm_ind].forward(h,
                        weights=cm_weights[2*cm_ind:2*cm_ind+2],
                        ckpt_id=cmod_cond)
                    cm_ind += 1

                # Non-linearity (if context-dependent mod is applied pre
                # non-lin.)
                if not self._context_mod_post_activation:
                    h = torch.tanh(h)

            if collect_activations and layer_ind == len(w_weights) - 2:
                last_hidden = h

            layer_ind += 1

        # Context-dependent modulation in output layer.
        if self._use_context_mod and not self._no_last_layer_context_mod:
            h = self._context_mod_layers[cm_ind].forward(h,
                weights=cm_weights[2*cm_ind:2*cm_ind+2], ckpt_id=cmod_cond)

        if collect_activations:
            return h, activations, last_hidden
        else:
            return h
    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        This network does not have any distillation targets.

        Returns:
            ``None``
        """
        return None

if __name__ == '__main__':
    pass
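The bookkeeping of context-mod parameters and of the flat `weights` list can be sketched without PyTorch. The helpers `cm_shapes`, `num_cm_weights` and `split_weights` are hypothetical and not part of hypnettorch. `cm_shapes` mirrors how the constructor collects the modulated output shapes for the default 32 x 32 input, including the `context_mod_apply_pixel_wise` reduction to `[C, 1, 1]`. The count assumes that each `ContextModLayer` owns exactly one gain and one shift of the modulated shape, which matches `n_cm = 2 * len(self.context_mod_layers)` in `forward`. `split_weights` follows the slicing in `forward`: context-mod tensors come first, and the internal tensors alternate between weight (even index) and bias (odd index).

```python
from math import prod


def cm_shapes(pixel_wise=False, inputs=False, last_layer=True,
              in_shape=(32, 32, 3), num_classes=10):
    """Shapes that receive a context-mod layer (assumes a 32 x 32 input)."""
    shapes = []
    if inputs:
        shapes.append([in_shape[2], in_shape[0], in_shape[1]])
    # Outputs of the three bio-conv layers and of the hidden FC layer.
    shapes += [[64, 14, 14], [128, 5, 5], [256, 5, 5], [1024]]
    if last_layer:
        shapes.append([num_classes])
    if not pixel_wise:
        # One gain/shift per feature map instead of one per pixel.
        shapes = [[s[0], 1, 1] if len(s) == 3 else s for s in shapes]
    return shapes


def num_cm_weights(shapes):
    # Assumption: every context-mod layer has a gain and a shift tensor,
    # both of the modulated shape.
    return sum(2 * prod(s) for s in shapes)


def split_weights(weights, n_cm_layers):
    """Split a flat weight list into (cm weights, weight tensors, biases)."""
    n_cm = 2 * n_cm_layers
    cm, internal = weights[:n_cm], weights[n_cm:]
    return cm, internal[0::2], internal[1::2]


print(num_cm_weights(cm_shapes()))                 # 2964
print(num_cm_weights(cm_shapes(pixel_wise=True)))  # 46356

names = ['g0', 's0', 'g1', 's1', 'W0', 'b0', 'W1', 'b1']
print(split_weights(names, 2))
# (['g0', 's0', 'g1', 's1'], ['W0', 'W1'], ['b0', 'b1'])
```

Enabling `context_mod_apply_pixel_wise` grows the context-mod parameters from 2964 to 46356 in this setting, almost entirely because of the 64 x 14 x 14 feature maps of the first layer.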