#!/usr/bin/env python3
# Copyright 2019 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title          :mnets/mlp.py
# @author         :ch
# @contact        :henningc@ethz.ch
# @created        :10/21/2019
# @version        :1.0
# @python_version :3.6.8
"""
Multi-Layer Perceptron
----------------------

Implementation of a fully-connected neural network.

An example use case is as a main model that itself contains no trainable
weights. Instead, weights are received as additional inputs, for instance,
produced by an auxiliary network, a so-called hypernetwork, see

    Ha et al., "HyperNetworks", arXiv, 2016,
    https://arxiv.org/abs/1609.09106
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from hypnettorch.mnets.mnet_interface import MainNetInterface
from hypnettorch.utils.batchnorm_layer import BatchNormLayer
from hypnettorch.utils.context_mod_layer import ContextModLayer
from hypnettorch.utils.torch_utils import init_params
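
# Example usage (a minimal sketch): an MLP built with ``no_weights=True``
# owns no parameters, so weights must be supplied to :meth:`MLP.forward`.
# The random tensors below merely stand in for the output of a hypernetwork
# and match the shapes listed in ``param_shapes``:
#
#     mlp = MLP(n_in=10, n_out=2, hidden_layers=(32, 32), no_weights=True)
#     weights = [torch.randn(*s) for s in mlp.param_shapes]
#     y = mlp.forward(torch.randn(8, 10), weights=weights)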

class MLP(nn.Module, MainNetInterface):
    r"""Implementation of a Multi-Layer Perceptron (MLP).

    This is a simple fully-connected network, that receives input vector
    :math:`\mathbf{x}` and outputs a vector :math:`\mathbf{y}` of real values.

    The output mapping does not include a non-linearity by default, as we
    want to map to the whole real line (but see argument ``out_fn``).

    Args:
        n_in (int): Number of inputs.
        n_out (int): Number of outputs.
        hidden_layers (list or tuple): A list of integers, each number
            denoting the size of a hidden layer.
        activation_fn: The nonlinearity used in hidden layers. If ``None``, no
            nonlinearity will be applied.
        use_bias (bool): Whether layers may have bias terms.
        no_weights (bool): If set to ``True``, no trainable parameters will be
            constructed, i.e., weights are assumed to be produced ad-hoc by a
            hypernetwork and passed to the :meth:`forward` method.
        init_weights (optional): This option is for convenience reasons.
            The option expects a list of parameter values that are used to
            initialize the network weights. As such, it provides a convenient
            way of initializing a network with a weight draw produced by the
            hypernetwork.

            Note, only the internal weights (see
            :attr:`mnets.mnet_interface.MainNetInterface.weights`) will be
            affected by this argument.
        dropout_rate: If ``-1``, no dropout will be applied. Otherwise a
            number between 0 and 1 is expected, denoting the dropout rate of
            hidden layers.
        use_spectral_norm: Use spectral normalization for training.
        use_batch_norm (bool): Whether batch normalization should be used.
            Will be applied before the activation function in all hidden
            layers.
        bn_track_stats (bool): If batch normalization is used, then this
            option determines whether running statistics are tracked in
            these layers or not (see argument ``track_running_stats`` of
            class :class:`utils.batchnorm_layer.BatchNormLayer`).

            If ``False``, then batch statistics are utilized even during
            evaluation. If ``True``, then running stats are tracked. When
            using this network in a continual learning scenario with
            different tasks, the running statistics are expected to be
            maintained externally. The argument ``stats_id`` of the method
            :meth:`utils.batchnorm_layer.BatchNormLayer.forward` can be
            provided using the argument ``condition`` of method
            :meth:`forward`.

            Example:
                To maintain the running stats, one can simply iterate over
                all batch norm layers and checkpoint the current running
                stats (e.g., after learning a task in a continual learning
                scenario).

                .. code:: python

                    for bn_layer in net.batchnorm_layers:
                        bn_layer.checkpoint_stats()
        distill_bn_stats (bool): If ``True``, then the shapes of the
            batchnorm statistics will be added to the attribute
            :attr:`mnets.mnet_interface.MainNetInterface.hyper_shapes_distilled`
            and the current statistics will be returned by the method
            :meth:`distillation_targets`.

            Note, this attribute may only be ``True`` if ``bn_track_stats``
            is ``True``.
        use_context_mod (bool): Add context-dependent modulation layers
            :class:`utils.context_mod_layer.ContextModLayer` after the linear
            computation of each layer.
        context_mod_inputs (bool): Whether context-dependent modulation
            should also be applied to network inputs directly. I.e., assume
            :math:`\mathbf{x}` is the input to the network. Then the first
            network operation would be to modify the input via
            :math:`\mathbf{x} \cdot \mathbf{g} + \mathbf{s}` using
            context-dependent gain and shift parameters.

            Note:
                Argument applies only if ``use_context_mod`` is ``True``.
        no_last_layer_context_mod (bool): If ``True``, context-dependent
            modulation will not be applied to the output layer.

            Note:
                Argument applies only if ``use_context_mod`` is ``True``.
        context_mod_no_weights (bool): The weights of the context-mod layers
            (:class:`utils.context_mod_layer.ContextModLayer`) are treated
            independently of the option ``no_weights``.
            This argument can be used to decide whether the context-mod
            parameters (gains and shifts) are maintained internally or
            externally.

            Note:
                Check out argument ``weights`` of the :meth:`forward` method
                on how to correctly pass weights to the network that are
                externally maintained.
        context_mod_post_activation (bool): Apply context-mod layers after
            the activation function (``activation_fn``) in hidden layers
            rather than before, which is the default behavior.

            Note:
                This option only applies if ``use_context_mod`` is ``True``.

            Note:
                This option does not affect argument ``context_mod_inputs``.

            Note:
                This option does not affect argument
                ``no_last_layer_context_mod``. Hence, if an output
                non-linearity is applied through argument ``out_fn``, then
                context-modulation would be applied before this
                non-linearity.
        context_mod_gain_offset (bool): Activates option ``apply_gain_offset``
            of class :class:`utils.context_mod_layer.ContextModLayer` for all
            context-mod layers that will be instantiated.
        context_mod_gain_softplus (bool): Activates option
            ``apply_gain_softplus`` of class
            :class:`utils.context_mod_layer.ContextModLayer` for all
            context-mod layers that will be instantiated.
        out_fn (optional): If provided, this function will be applied to the
            output neurons of the network.

            Warning:
                This changes the interpretation of the output of the
                :meth:`forward` method.
        verbose (bool): Whether to print information (e.g., the number of
            weights) during the construction of the network.
    """
    def __init__(self, n_in=1, n_out=1, hidden_layers=(10, 10),
                 activation_fn=torch.nn.ReLU(), use_bias=True,
                 no_weights=False, init_weights=None, dropout_rate=-1,
                 use_spectral_norm=False, use_batch_norm=False,
                 bn_track_stats=True, distill_bn_stats=False,
                 use_context_mod=False, context_mod_inputs=False,
                 no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 context_mod_gain_softplus=False,
                 out_fn=None, verbose=True):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        # FIXME Spectral norm is incorrectly implemented. Function
        # `nn.utils.spectral_norm` needs to be called in the constructor,
        # such that spectral norm is wrapped around a module.
        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                                      'implemented for this network.')

        if use_batch_norm and use_context_mod:
            # FIXME Does it make sense to have both enabled?
            # I.e., should we produce a warning or error?
            pass

        # Tuples are not mutable.
        hidden_layers = list(hidden_layers)

        self._a_fun = activation_fn
        assert(init_weights is None or \
               (not no_weights or not context_mod_no_weights))
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate
        #self._use_spectral_norm = use_spectral_norm
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._context_mod_gain_softplus = context_mod_gain_softplus
        self._out_fn = out_fn

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True if out_fn is None else False

        if use_spectral_norm and no_weights:
            raise ValueError('Cannot use spectral norm in a network ' +
                             'without parameters.')

        # FIXME make sure that this implementation is correct in all
        # situations (e.g., what to do if weights are passed to the forward
        # method?).
        if use_spectral_norm:
            self._spec_norm = nn.utils.spectral_norm
        else:
            self._spec_norm = lambda x: x  # identity

        self._param_shapes = []
        self._param_shapes_meta = []
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []

        if dropout_rate != -1:
            assert(dropout_rate >= 0. and dropout_rate <= 1.)
            self._dropout = nn.Dropout(p=dropout_rate)

        ### Define and initialize context mod weights.
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None

        if use_context_mod:
            cm_ind = 0
            cm_sizes = []
            if context_mod_inputs:
                cm_sizes.append(n_in)
            cm_sizes.extend(hidden_layers)
            if not no_last_layer_context_mod:
                cm_sizes.append(n_out)

            for i, n in enumerate(cm_sizes):
                cmod_layer = ContextModLayer(n,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset,
                    apply_gain_softplus=context_mod_gain_softplus)
                self._context_mod_layers.append(cmod_layer)

                self._param_shapes.extend(cmod_layer.param_shapes)
                assert len(cmod_layer.param_shapes) == 2
                self._param_shapes_meta.extend([
                    {'name': 'cm_scale',
                     'index': -1 if context_mod_no_weights else \
                         len(self._weights),
                     'layer': -1},  # 'layer' is set later.
                    {'name': 'cm_shift',
                     'index': -1 if context_mod_no_weights else \
                         len(self._weights) + 1,
                     'layer': -1},  # 'layer' is set later.
                ])

                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if context_mod_no_weights:
                    self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
                else:
                    self._weights.extend(cmod_layer.weights)

                # FIXME ugly code. Move initialization somewhere else.
                if not context_mod_no_weights and init_weights is not None:
                    assert(len(cmod_layer.weights) == 2)
                    for ii in range(2):
                        assert(np.all(np.equal(
                            list(init_weights[cm_ind].shape),
                            list(cmod_layer.weights[ii].shape))))
                        cmod_layer.weights[ii].data = init_weights[cm_ind]
                        cm_ind += 1

            if init_weights is not None:
                init_weights = init_weights[cm_ind:]

            if context_mod_no_weights:
                self._hyper_shapes_learned_ref = \
                    list(range(len(self._param_shapes)))

        ### Define and initialize batch norm weights.
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        if use_batch_norm:
            if distill_bn_stats:
                self._hyper_shapes_distilled = []

            bn_ind = 0
            for i, n in enumerate(hidden_layers):
                bn_layer = BatchNormLayer(n, affine=not no_weights,
                    track_running_stats=bn_track_stats)
                self._batchnorm_layers.append(bn_layer)

                self._param_shapes.extend(bn_layer.param_shapes)
                assert len(bn_layer.param_shapes) == 2
                self._param_shapes_meta.extend([
                    {'name': 'bn_scale',
                     'index': -1 if no_weights else len(self._weights),
                     'layer': -1},  # 'layer' is set later.
                    {'name': 'bn_shift',
                     'index': -1 if no_weights else len(self._weights) + 1,
                     'layer': -1},  # 'layer' is set later.
                ])

                if no_weights:
                    self._hyper_shapes_learned.extend(bn_layer.param_shapes)
                else:
                    self._weights.extend(bn_layer.weights)

                if distill_bn_stats:
                    self._hyper_shapes_distilled.extend( \
                        [list(p.shape) for p in bn_layer.get_stats(0)])

                # FIXME ugly code. Move initialization somewhere else.
                if not no_weights and init_weights is not None:
                    assert(len(bn_layer.weights) == 2)
                    for ii in range(2):
                        assert(np.all(np.equal(
                            list(init_weights[bn_ind].shape),
                            list(bn_layer.weights[ii].shape))))
                        bn_layer.weights[ii].data = init_weights[bn_ind]
                        bn_ind += 1

            if init_weights is not None:
                init_weights = init_weights[bn_ind:]

        ### Compute shapes of linear layers.
        linear_shapes = MLP.weight_shapes(n_in=n_in, n_out=n_out,
            hidden_layers=hidden_layers, use_bias=use_bias)
        self._param_shapes.extend(linear_shapes)

        for i, s in enumerate(linear_shapes):
            self._param_shapes_meta.append({
                'name': 'weight' if len(s) != 1 else 'bias',
                'index': -1 if no_weights else len(self._weights) + i,
                'layer': -1  # 'layer' is set later.
            })

        num_weights = MainNetInterface.shapes_to_num_weights( \
            self._param_shapes)

        ### Set missing meta information of param_shapes.
        offset = 1 if use_context_mod and context_mod_inputs else 0
        shift = 1
        if use_batch_norm:
            shift += 1
        if use_context_mod:
            shift += 1

        cm_offset = 2 if context_mod_post_activation else 1
        bn_offset = 1 if context_mod_post_activation else 2

        cm_ind = 0
        bn_ind = 0
        layer_ind = 0
        for i, dd in enumerate(self._param_shapes_meta):
            if dd['name'].startswith('cm'):
                if offset == 1 and i in [0, 1]:
                    dd['layer'] = 0
                else:
                    if cm_ind < len(hidden_layers):
                        dd['layer'] = offset + cm_ind * shift + cm_offset
                    else:
                        assert cm_ind == len(hidden_layers) and \
                            not no_last_layer_context_mod
                        # No batchnorm in the output layer.
                        dd['layer'] = offset + cm_ind * shift + 1

                    if dd['name'] == 'cm_shift':
                        cm_ind += 1
            elif dd['name'].startswith('bn'):
                dd['layer'] = offset + bn_ind * shift + bn_offset
                if dd['name'] == 'bn_shift':
                    bn_ind += 1
            else:
                dd['layer'] = offset + layer_ind * shift
                if not use_bias or dd['name'] == 'bias':
                    layer_ind += 1

        ### User information.
        if verbose:
            if use_context_mod:
                cm_num_weights = 0
                for cm_layer in self._context_mod_layers:
                    cm_num_weights += MainNetInterface.shapes_to_num_weights( \
                        cm_layer.param_shapes)
            print('Creating an MLP with %d weights' % num_weights
                  + (' (including %d weights associated with '
                     % cm_num_weights + 'context modulation)'
                     if use_context_mod else '') + '.'
                  + (' The network uses dropout.' if dropout_rate != -1
                     else '')
                  + (' The network uses batchnorm.' if use_batch_norm
                     else ''))

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if no_weights:
            self._hyper_shapes_learned.extend(linear_shapes)

            if use_context_mod:
                if context_mod_no_weights:
                    self._hyper_shapes_learned_ref = \
                        list(range(len(self._param_shapes)))
                else:
                    ncm = len(self._context_mod_shapes)
                    self._hyper_shapes_learned_ref = \
                        list(range(ncm, len(self._param_shapes)))

            self._is_properly_setup()
            return

        ### Define and initialize linear weights.
        for i, dims in enumerate(linear_shapes):
            self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                              requires_grad=True))
            if len(dims) == 1:
                self._layer_bias_vectors.append(self._weights[-1])
            else:
                self._layer_weight_tensors.append(self._weights[-1])

        if init_weights is not None:
            assert(len(init_weights) == len(linear_shapes))
            for i in range(len(init_weights)):
                assert(np.all(np.equal(list(init_weights[i].shape),
                                       linear_shapes[i])))
                if use_bias:
                    if i % 2 == 0:
                        self._layer_weight_tensors[i//2].data = \
                            init_weights[i]
                    else:
                        self._layer_bias_vectors[i//2].data = init_weights[i]
                else:
                    self._layer_weight_tensors[i].data = init_weights[i]
        else:
            for i in range(len(self._layer_weight_tensors)):
                if use_bias:
                    init_params(self._layer_weight_tensors[i],
                                self._layer_bias_vectors[i])
                else:
                    init_params(self._layer_weight_tensors[i])

        if self._num_context_mod_shapes() == 0:
            # Note, this might be the case if no hidden layers exist and no
            # input or output modulation is used.
            self._use_context_mod = False

        self._is_properly_setup()
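
    # A worked example of the 'layer' bookkeeping above (a sketch derived from
    # the code): with one hidden layer, ``use_batch_norm=True``,
    # ``use_context_mod=True``, ``context_mod_inputs=True`` and pre-activation
    # context-mod, we get ``offset=1``, ``shift=3``, ``cm_offset=1`` and
    # ``bn_offset=2``, i.e., the 'layer' indices
    #
    #     0: input context-mod
    #     1: first linear layer     (offset + 0*shift)
    #     2: hidden context-mod     (offset + 0*shift + cm_offset)
    #     3: hidden batchnorm       (offset + 0*shift + bn_offset)
    #     4: output linear layer    (offset + 1*shift)
    #     5: output context-mod     (offset + 1*shift + 1)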

    def forward(self, x, weights=None, distilled_params=None, condition=None):
        """Compute the output :math:`y` of this network given the input
        :math:`x`.

        Args:
            (....): See docstring of method
                :meth:`mnets.mnet_interface.MainNetInterface.forward`. We
                provide some more specific information below.
            weights (list or dict): If a list of parameter tensors is given
                and context modulation is used (see argument
                ``use_context_mod`` in constructor), then these parameters
                are interpreted as context-modulation parameters if the
                length of ``weights`` equals
                :code:`2*len(net.context_mod_layers)`. Otherwise, the length
                is expected to be equal to the length of the attribute
                :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`.

                Alternatively, a dictionary can be passed with the possible
                keywords ``internal_weights`` and ``mod_weights``. Each
                keyword is expected to map onto a list of tensors.
                The keyword ``internal_weights`` refers to all weights of
                this network except for the weights of the
                context-modulation layers.
                The keyword ``mod_weights``, on the other hand, refers
                specifically to the weights of the context-modulation layers.
                It is not necessary to specify both keywords.
            distilled_params: Will be passed as ``running_mean`` and
                ``running_var`` arguments of method
                :meth:`utils.batchnorm_layer.BatchNormLayer.forward` if
                batch normalization is used.
            condition (int or dict, optional): If an ``int`` is provided,
                then this argument will be passed as argument ``stats_id``
                to the method
                :meth:`utils.batchnorm_layer.BatchNormLayer.forward` if
                batch normalization is used.

                If a ``dict`` is provided instead, the following keywords
                are allowed:

                    - ``bn_stats_id``: Will be handled as ``stats_id`` of
                      the batchnorm layers as described above.
                    - ``cmod_ckpt_id``: Will be passed as argument
                      ``ckpt_id`` to the method
                      :meth:`utils.context_mod_layer.ContextModLayer.forward`.

        Returns:
            (tuple): Tuple containing:

            - **y**: The output of the network.
            - **h_y** (optional): If ``out_fn`` was specified in the
              constructor, then this value will be returned. It is the last
              hidden activation (before the ``out_fn`` has been applied).
        """
        if ((not self._use_context_mod and self._no_weights) or \
                (self._no_weights or self._context_mod_no_weights)) and \
                weights is None:
            raise Exception('Network was generated without weights. ' +
                            'Hence, "weights" option may not be None.')

        ############################################
        ### Extract which weights should be used ###
        ############################################
        # I.e., are we using internally maintained weights or externally
        # given ones, or are we even mixing between these groups?
        n_cm = self._num_context_mod_shapes()

        if weights is None:
            weights = self.weights

            if self._use_context_mod:
                cm_weights = weights[:n_cm]
                int_weights = weights[n_cm:]
            else:
                int_weights = weights
        else:
            int_weights = None
            cm_weights = None

            if isinstance(weights, dict):
                assert('internal_weights' in weights.keys() or \
                       'mod_weights' in weights.keys())
                if 'internal_weights' in weights.keys():
                    int_weights = weights['internal_weights']
                if 'mod_weights' in weights.keys():
                    cm_weights = weights['mod_weights']
            else:
                if self._use_context_mod and len(weights) == n_cm:
                    cm_weights = weights
                else:
                    assert(len(weights) == len(self.param_shapes))
                    if self._use_context_mod:
                        cm_weights = weights[:n_cm]
                        int_weights = weights[n_cm:]
                    else:
                        int_weights = weights

            if self._use_context_mod and cm_weights is None:
                if self._context_mod_no_weights:
                    raise Exception('Network was generated without weights ' +
                                    'for context-mod layers. Hence, they ' +
                                    'must be passed via the "weights" ' +
                                    'option.')
                cm_weights = self.weights[:n_cm]
            if int_weights is None:
                if self._no_weights:
                    raise Exception('Network was generated without internal ' +
                                    'weights. Hence, they must be passed ' +
                                    'via the "weights" option.')
                if self._context_mod_no_weights:
                    int_weights = self.weights
                else:
                    int_weights = self.weights[n_cm:]

            # Note, context-mod weights might have different shapes, as they
            # may be parametrized on a per-sample basis.
            if self._use_context_mod:
                assert(len(cm_weights) == len(self._context_mod_shapes))
            int_shapes = self.param_shapes[n_cm:]
            assert(len(int_weights) == len(int_shapes))
            for i, s in enumerate(int_shapes):
                assert(np.all(np.equal(s, list(int_weights[i].shape))))

        cm_ind = 0
        bn_ind = 0

        if self._use_batch_norm:
            n_bn = 2 * len(self.batchnorm_layers)
            bn_weights = int_weights[:n_bn]
            layer_weights = int_weights[n_bn:]
        else:
            layer_weights = int_weights

        w_weights = []
        b_weights = []
        for i, p in enumerate(layer_weights):
            if self.has_bias and i % 2 == 1:
                b_weights.append(p)
            else:
                w_weights.append(p)

        #######################
        ### Parse condition ###
        #######################

        bn_cond = None
        cmod_cond = None

        if condition is not None:
            if isinstance(condition, dict):
                assert('bn_stats_id' in condition.keys() or \
                       'cmod_ckpt_id' in condition.keys())
                if 'bn_stats_id' in condition.keys():
                    bn_cond = condition['bn_stats_id']
                if 'cmod_ckpt_id' in condition.keys():
                    cmod_cond = condition['cmod_ckpt_id']

                    # FIXME We always require context-mod weights above, but
                    # we can't pass both (a condition and weights) to the
                    # context-mod layers.
                    # An inelegant solution would be to just set all
                    # context-mod weights to None.
                    raise NotImplementedError('CM-conditions not ' +
                                              'implemented!')
            else:
                bn_cond = condition

        ######################################
        ### Select batchnorm running stats ###
        ######################################
        if self._use_batch_norm:
            # Use a distinct name here to avoid shadowing the module alias
            # `nn`.
            num_bn = len(self._batchnorm_layers)
            running_means = [None] * num_bn
            running_vars = [None] * num_bn

        if distilled_params is not None:
            if not self._distill_bn_stats:
                raise ValueError('Argument "distilled_params" can only be ' +
                                 'provided if the return value of method ' +
                                 '"distillation_targets()" is not None.')

            shapes = self.hyper_shapes_distilled
            assert(len(distilled_params) == len(shapes))
            for i, s in enumerate(shapes):
                assert(np.all(np.equal(s, list(distilled_params[i].shape))))

            # Extract batchnorm stats from distilled_params.
            for i in range(0, len(distilled_params), 2):
                running_means[i//2] = distilled_params[i]
                running_vars[i//2] = distilled_params[i+1]

        elif self._use_batch_norm and self._bn_track_stats and \
                bn_cond is None:
            for i, bn_layer in enumerate(self._batchnorm_layers):
                running_means[i], running_vars[i] = bn_layer.get_stats()

        ###########################
        ### Forward Computation ###
        ###########################
        hidden = x

        # Context-dependent modulation of inputs directly.
        if self._use_context_mod and self._context_mod_inputs:
            hidden = self._context_mod_layers[cm_ind].forward(hidden,
                weights=cm_weights[2*cm_ind:2*cm_ind+2], ckpt_id=cmod_cond)
            cm_ind += 1

        for l in range(len(w_weights)):
            W = w_weights[l]
            if self.has_bias:
                b = b_weights[l]
            else:
                b = None

            # Linear layer.
            hidden = self._spec_norm(F.linear(hidden, W, bias=b))

            # Only for hidden layers.
            if l < len(w_weights) - 1:
                # Context-dependent modulation (pre-activation).
                if self._use_context_mod and \
                        not self._context_mod_post_activation:
                    hidden = self._context_mod_layers[cm_ind].forward(hidden,
                        weights=cm_weights[2*cm_ind:2*cm_ind+2],
                        ckpt_id=cmod_cond)
                    cm_ind += 1

                # Batch norm.
                if self._use_batch_norm:
                    hidden = self._batchnorm_layers[bn_ind].forward(hidden,
                        running_mean=running_means[bn_ind],
                        running_var=running_vars[bn_ind],
                        weight=bn_weights[2*bn_ind],
                        bias=bn_weights[2*bn_ind+1], stats_id=bn_cond)
                    bn_ind += 1

                # Dropout.
                if self._dropout_rate != -1:
                    hidden = self._dropout(hidden)

                # Non-linearity.
                if self._a_fun is not None:
                    hidden = self._a_fun(hidden)

                # Context-dependent modulation (post-activation).
                if self._use_context_mod and \
                        self._context_mod_post_activation:
                    hidden = self._context_mod_layers[cm_ind].forward(hidden,
                        weights=cm_weights[2*cm_ind:2*cm_ind+2],
                        ckpt_id=cmod_cond)
                    cm_ind += 1

        # Context-dependent modulation in the output layer.
        if self._use_context_mod and not self._no_last_layer_context_mod:
            hidden = self._context_mod_layers[cm_ind].forward(hidden,
                weights=cm_weights[2*cm_ind:2*cm_ind+2], ckpt_id=cmod_cond)

        if self._out_fn is not None:
            return self._out_fn(hidden), hidden

        return hidden
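
    # Example (a minimal sketch): if only the context-mod parameters are
    # maintained externally (``context_mod_no_weights=True``), it suffices to
    # pass those via the ``mod_weights`` keyword. Gains of one and shifts of
    # zero should yield an identity modulation under the default settings:
    #
    #     net = MLP(n_in=10, n_out=2, hidden_layers=(32,),
    #               use_context_mod=True, context_mod_no_weights=True)
    #     n_cm = 2 * len(net.context_mod_layers)
    #     cm_ws = [torch.ones(*s) if i % 2 == 0 else torch.zeros(*s)
    #              for i, s in enumerate(net.param_shapes[:n_cm])]
    #     y = net.forward(torch.randn(8, 10), weights={'mod_weights': cm_ws})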

    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        This method will return the current batch statistics of all batch
        normalization layers if ``distill_bn_stats`` and ``use_batch_norm``
        were set to ``True`` in the constructor.

        Returns:
            The target tensors corresponding to the shapes specified in
            attribute :attr:`hyper_shapes_distilled`.
        """
        if self.hyper_shapes_distilled is None:
            return None

        ret = []
        for bn_layer in self._batchnorm_layers:
            ret.extend(bn_layer.get_stats())

        return ret
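
    # Example (a minimal sketch): with ``use_batch_norm=True`` and
    # ``distill_bn_stats=True``, the statistics returned here can later be
    # re-injected through the ``distilled_params`` argument of
    # :meth:`forward`:
    #
    #     targets = net.distillation_targets()  # [mean_0, var_0, mean_1, ...]
    #     y = net.forward(x, distilled_params=targets)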

    @staticmethod
    def weight_shapes(n_in=1, n_out=1, hidden_layers=[10, 10], use_bias=True):
        """Compute the tensor shapes of all parameters in a fully-connected
        network.

        Args:
            n_in: Number of inputs.
            n_out: Number of output units.
            hidden_layers: A list of ints, each number denoting the size of a
                hidden layer.
            use_bias: Whether the FC layers should have biases.

        Returns:
            A list of lists of integers, denoting the shapes of the
            individual parameter tensors.
        """
        shapes = []

        prev_dim = n_in
        layer_out_sizes = hidden_layers + [n_out]
        for i, size in enumerate(layer_out_sizes):
            shapes.append([size, prev_dim])
            if use_bias:
                shapes.append([size])
            prev_dim = size

        return shapes
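
    # For instance, one hidden layer of size 4 between 3 inputs and 2 outputs
    # yields (weight matrices are stored as ``[out_size, in_size]``):
    #
    #     >>> MLP.weight_shapes(n_in=3, n_out=2, hidden_layers=[4])
    #     [[4, 3], [4], [2, 4], [2]]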

if __name__ == '__main__':
    pass