Source code for hypnettorch.hnets.mlp_hnet

#!/usr/bin/env python3
# Copyright 2020 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title          :hnets/mlp_hnet.py
# @author         :ch
# @contact        :henningc@ethz.ch
# @created        :04/14/2020
# @version        :1.0
# @python_version :3.6.10
"""
MLP - Hypernetwork
------------------

The module :mod:`hnets.mlp_hnet` contains a fully-connected hypernetwork
(also termed `full hypernet`).

This type of hypernetwork represents one of the simplest architectural choices
to realize a weight generator. An embedding input, which may consist of
conditional and unconditional parts (for instance, in the case of a
`task-conditioned hypernetwork <https://arxiv.org/abs/1906.00695>`__ the
conditional input will be a task embedding), is mapped via a series of
fully-connected layers onto a final hidden representation. Then a linear
fully-connected output head per target tensor is used to produce the target
weights, i.e., output tensors with shapes specified via the target shapes (see
:attr:`hnets.hnet_interface.HyperNetInterface.target_shapes`).

If no hidden layers are used, then this resembles a simplistic linear
hypernetwork, where the input embeddings are linearly mapped onto target
weights.
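
A minimal usage sketch (the target shapes and layer sizes below are purely
illustrative):

.. code-block:: python

    from hypnettorch.hnets.mlp_hnet import HMLP

    # Shapes of the tensors the hypernet should output, e.g., the weight
    # matrices and bias vectors of a small target MLP.
    target_shapes = [[20, 10], [20], [5, 20], [5]]

    hnet = HMLP(target_shapes, uncond_in_size=0, cond_in_size=8,
                layers=(100, 100), num_cond_embs=3)

    # Generate the target weights for condition (e.g., task) 1.
    weights = hnet.forward(cond_id=1)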
"""
from collections import defaultdict
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from warnings import warn

from hypnettorch.hnets.hnet_interface import HyperNetInterface
from hypnettorch.mnets.mnet_interface import MainNetInterface
from hypnettorch.utils import init_utils as iutils

class HMLP(nn.Module, HyperNetInterface):
    """Implementation of a `full hypernet`.

    The network will consist of several hidden layers and a final linear
    output layer that produces all weight matrices/bias-vectors the network
    has to produce.

    The network allows maintaining a set of embeddings internally that can be
    used as conditional input.

    Args:
        target_shapes (list): List of lists of integers, i.e., a list of
            tensor shapes. Those will be the shapes of the output weights
            produced by the hypernetwork. For each entry in this list, a
            separate output layer will be instantiated.
        uncond_in_size (int): The size of unconditional inputs (for instance,
            noise).
        cond_in_size (int): The size of conditional input embeddings.

            Note, if ``no_cond_weights`` is ``False``, those embeddings will be
            maintained internally.
        layers (list or tuple): List of integers denoting the sizes of each
            hidden layer. If empty, no hidden layers will be produced.
        verbose (bool): Whether network information should be printed during
            network creation.
        activation_fn (func): The activation function to be used for hidden
            activations. For instance, an instance of class
            :class:`torch.nn.ReLU`.
        use_bias (bool): Whether the fully-connected layers that make up this
            network should have bias vectors.
        no_uncond_weights (bool): If ``True``, unconditional weights are not
            maintained internally and instead expected to be produced
            externally and passed to the :meth:`forward` method.
        no_cond_weights (bool): If ``True``, conditional embeddings are assumed
            to be maintained externally. Otherwise, option ``num_cond_embs``
            has to be properly set, which will determine the number of
            embeddings that are internally maintained.
        num_cond_embs (int): Number of conditional embeddings to be internally
            maintained. Only used if option ``no_cond_weights`` is ``False``.

            Note:
                Embeddings will be initialized with a normal distribution using
                zero mean and unit variance.
        dropout_rate (float): If ``-1``, no dropout will be applied. Otherwise
            a number between 0 and 1 is expected, denoting the dropout rate of
            hidden layers.
        use_spectral_norm (bool): Use spectral normalization for training.
        use_batch_norm (bool): Whether batch normalization should be used. Will
            be applied before the activation function in all hidden layers.

            Note:
                Batch norm only makes sense if the hypernetwork is invoked
                with batch sizes greater than 1 during training.
    """
    def __init__(self, target_shapes, uncond_in_size=0, cond_in_size=8,
                 layers=(100, 100), verbose=True, activation_fn=torch.nn.ReLU(),
                 use_bias=True, no_uncond_weights=False, no_cond_weights=False,
                 num_cond_embs=1, dropout_rate=-1, use_spectral_norm=False,
                 use_batch_norm=False):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        HyperNetInterface.__init__(self)

        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                'implemented for this hypernetwork type.')

        assert len(target_shapes) > 0
        if cond_in_size == 0 and num_cond_embs > 0:
            warn('Requested that conditional weights are managed, but ' +
                 'conditional input size is zero! Setting "num_cond_embs" to ' +
                 'zero.')
            num_cond_embs = 0
        elif not no_cond_weights and num_cond_embs == 0 and cond_in_size > 0:
            warn('Requested that conditional weights are internally ' +
                 'maintained, but "num_cond_embs" is zero.')

        # Do we maintain conditional weights internally?
        has_int_cond_weights = not no_cond_weights and num_cond_embs > 0
        # Do we expect external conditional weights?
        has_ext_cond_weights = no_cond_weights and num_cond_embs > 0

        ### Make constructor arguments internally available ###
        self._uncond_in_size = uncond_in_size
        self._cond_in_size = cond_in_size
        self._layers = layers
        self._act_fn = activation_fn
        self._no_uncond_weights = no_uncond_weights
        self._no_cond_weights = no_cond_weights
        self._num_cond_embs = num_cond_embs
        self._dropout_rate = dropout_rate
        self._use_spectral_norm = use_spectral_norm
        self._use_batch_norm = use_batch_norm

        ### Setup attributes required by interface ###
        self._target_shapes = target_shapes
        self._num_known_conds = self._num_cond_embs
        self._unconditional_param_shapes_ref = []

        self._has_bias = use_bias
        self._has_fc_out = True
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._param_shapes_meta = []
        self._internal_params = None if no_uncond_weights and \
            has_int_cond_weights else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_uncond_weights and has_ext_cond_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        self._dropout = None
        if dropout_rate != -1:
            assert dropout_rate > 0 and dropout_rate < 1
            self._dropout = nn.Dropout(dropout_rate)

        ### Create conditional weights ###
        for _ in range(num_cond_embs):
            assert cond_in_size > 0
            if not no_cond_weights:
                self._internal_params.append(nn.Parameter( \
                    data=torch.Tensor(cond_in_size), requires_grad=True))
                torch.nn.init.normal_(self._internal_params[-1], mean=0.,
                                      std=1.)
            else:
                self._hyper_shapes_learned.append([cond_in_size])
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            self._param_shapes.append([cond_in_size])
            # Embeddings belong to the input, so we just assign them all to
            # "layer" 0.
            self._param_shapes_meta.append({
                'name': 'embedding',
                'index': -1 if no_cond_weights else \
                    len(self._internal_params)-1,
                'layer': 0
            })

        ### Create batch-norm layers ###
        # We just use even numbers starting from 2 as layer indices for
        # batchnorm layers.
        if use_batch_norm:
            self._add_batchnorm_layers(layers, no_uncond_weights,
                bn_layers=list(range(2, 2*len(layers)+1, 2)),
                distill_bn_stats=False, bn_track_stats=True)

        ### Create fully-connected hidden-layers ###
        in_size = uncond_in_size + cond_in_size
        if len(layers) > 0:
            # We use odd numbers starting at 1 as layer indices for hidden
            # layers.
            self._add_fc_layers([in_size, *layers[:-1]], layers,
                no_uncond_weights, fc_layers=list(range(1, 2*len(layers), 2)))
            hidden_size = layers[-1]
        else:
            hidden_size = in_size

        ### Create fully-connected output-layers ###
        # Note, technically there is no difference between having a separate
        # fully-connected layer per target shape or a single fully-connected
        # layer producing all weights at once (in any case, each output is
        # connected to all hidden units).
        # I guess it is more computationally efficient to have one output layer
        # and then split the output according to the target shapes.
        self._add_fc_layers([hidden_size], [self.num_outputs],
                            no_uncond_weights, fc_layers=[2*len(layers)+1])

        ### Finalize construction ###
        # All parameters are unconditional except the embeddings created at
        # the very beginning.
        self._unconditional_param_shapes_ref = \
            list(range(num_cond_embs, len(self.param_shapes)))

        self._is_properly_setup()

        if verbose:
            print('Created MLP Hypernet.')
            print(self)
    def forward(self, uncond_input=None, cond_input=None, cond_id=None,
                weights=None, distilled_params=None, condition=None,
                ret_format='squeezed'):
        """Compute the weights of a target network.

        Args:
            (....): See docstring of method
                :meth:`hnets.hnet_interface.HyperNetInterface.forward`.
            condition (int, optional): This argument will be passed as argument
                ``stats_id`` to the method
                :meth:`utils.batchnorm_layer.BatchNormLayer.forward` if batch
                normalization is used.

        Returns:
            (list or torch.Tensor): See docstring of method
            :meth:`hnets.hnet_interface.HyperNetInterface.forward`.
        """
        uncond_input, cond_input, uncond_weights, _ = \
            self._preprocess_forward_args(uncond_input=uncond_input,
                cond_input=cond_input, cond_id=cond_id, weights=weights,
                distilled_params=distilled_params, condition=condition,
                ret_format=ret_format)

        ### Prepare hypernet input ###
        assert self._uncond_in_size == 0 or uncond_input is not None
        assert self._cond_in_size == 0 or cond_input is not None
        if uncond_input is not None:
            assert len(uncond_input.shape) == 2 and \
                uncond_input.shape[1] == self._uncond_in_size
            h = uncond_input
        if cond_input is not None:
            assert len(cond_input.shape) == 2 and \
                cond_input.shape[1] == self._cond_in_size
            h = cond_input
        if uncond_input is not None and cond_input is not None:
            h = torch.cat([uncond_input, cond_input], dim=1)

        ### Extract layer weights ###
        bn_scales = []
        bn_shifts = []
        fc_weights = []
        fc_biases = []

        assert len(uncond_weights) == len(self.unconditional_param_shapes_ref)
        for i, idx in enumerate(self.unconditional_param_shapes_ref):
            meta = self.param_shapes_meta[idx]

            if meta['name'] == 'bn_scale':
                bn_scales.append(uncond_weights[i])
            elif meta['name'] == 'bn_shift':
                bn_shifts.append(uncond_weights[i])
            elif meta['name'] == 'weight':
                fc_weights.append(uncond_weights[i])
            else:
                assert meta['name'] == 'bias'
                fc_biases.append(uncond_weights[i])

        if not self.has_bias:
            assert len(fc_biases) == 0
            fc_biases = [None] * len(fc_weights)

        if self._use_batch_norm:
            assert len(bn_scales) == len(fc_weights) - 1

        ### Process inputs through network ###
        for i in range(len(fc_weights)):
            last_layer = i == (len(fc_weights) - 1)

            h = F.linear(h, fc_weights[i], bias=fc_biases[i])

            if not last_layer:
                # Batch-norm
                if self._use_batch_norm:
                    h = self.batchnorm_layers[i].forward(h, running_mean=None,
                        running_var=None, weight=bn_scales[i],
                        bias=bn_shifts[i], stats_id=condition)

                # Dropout
                if self._dropout_rate != -1:
                    h = self._dropout(h)

                # Non-linearity
                if self._act_fn is not None:
                    h = self._act_fn(h)

        ### Split output into target shapes ###
        ret = self._flat_to_ret_format(h, ret_format)

        return ret
    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        This network does not have any distillation targets.

        Returns:
            ``None``
        """
        return None
    def apply_hyperfan_init(self, method='in', use_xavier=False,
                            uncond_var=1., cond_var=1., mnet=None,
                            w_val=None, w_var=None, b_val=None, b_var=None):
        r"""Initialize the network using `hyperfan init`.

        Hyperfan initialization was developed in the following paper for this
        kind of hypernetwork

            "Principled Weight Initialization for Hypernetworks"
            https://openreview.net/forum?id=H1lma24tPB

        The initialization is based on the following idea: When the main
        network would be initialized using Xavier or Kaiming init, then the
        variance of activations (fan-in) or gradients (fan-out) would be
        preserved by using a proper variance for the initial weight
        distribution (assuming certain assumptions hold at initialization,
        which are different for Xavier and Kaiming).

        When using this kind of initialization in the hypernetwork, then the
        variance of the initial main net weight distribution would simply
        equal the variance of the input embeddings (which can lead to
        exploding activations, e.g., for fan-in inits).

        The above mentioned paper proposes a quick fix for the type of
        hypernet that resembles the simple MLP hnet implemented in this class,
        i.e., hypernets that have a separate output head per weight tensor in
        the main network.

        Assuming that input embeddings are initialized with a certain variance
        (e.g., 1) and we use Xavier or Kaiming init for the hypernet, then the
        variance of the last hidden activation will also be 1.

        Then, we can modify the variance of the weights of each output head in
        the hypernet to obtain the same variance per main net weight tensor
        that we would typically obtain when applying Xavier or Kaiming to the
        main network directly.

        Note:
            If ``mnet`` is not provided or the corresponding attribute
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes_meta` is
            not implemented, then this method assumes that 1D target tensors
            (cf. constructor argument ``target_shapes``) represent bias vectors
            in the main network.

        Note:
            To compute the hyperfan-out initialization of bias vectors, we need
            access to the fan-in of the layer, which we can only compute based
            on the corresponding weight tensor in the same layer. This is only
            possible if ``mnet`` is provided. Otherwise, the following
            heuristic is applied. We assume that the shape directly preceding
            a bias shape in the constructor argument ``target_shapes`` is the
            corresponding weight tensor.

        Note:
            All hypernet inputs are assumed to be zero-mean random variables.

        **Variance of the hypernet input**

        In general, the input to the hypernetwork can be a concatenation of
        multiple embeddings (see description of arguments ``uncond_var`` and
        ``cond_var``).

        Let's denote the complete hypernetwork input by
        :math:`\mathbf{x} \in \mathbb{R}^n`, which consists of a conditional
        embedding :math:`\mathbf{e} \in \mathbb{R}^{n_e}` and an unconditional
        input :math:`\mathbf{c} \in \mathbb{R}^{n_c}`, i.e.,

        .. math::

            \mathbf{x} = \begin{bmatrix} \
            \mathbf{e} \\ \
            \mathbf{c} \
            \end{bmatrix}

        We simply define the variance of an input :math:`\text{Var}(x_j)` as
        the weighted average of the individual variances, i.e.,

        .. math::

            \text{Var}(x_j) \equiv \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        To see that this is correct, consider a linear layer
        :math:`\mathbf{y} = W \mathbf{x}` or

        .. math::

            y_i &= \sum_j w_{ij} x_j \\ \
                &= \sum_{j=1}^{n_e} w_{ij} e_j + \
                   \sum_{j=n_e+1}^{n_e+n_c} w_{ij} c_{j-n_e}

        Hence, we can compute the variance of :math:`y_i` as follows (assuming
        the typical Xavier assumptions):

        .. math::

            \text{Var}(y) &= n_e \text{Var}(w) \text{Var}(e) + \
                n_c \text{Var}(w) \text{Var}(c) \\ \
                &= \frac{n_e}{n_e+n_c} \text{Var}(e) + \
                   \frac{n_c}{n_e+n_c} \text{Var}(c)

        Note, that Xavier would have initialized :math:`W` using
        :math:`\text{Var}(w) = \frac{1}{n} = \frac{1}{n_e+n_c}`.

        Args:
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``'in'``: Use `Hyperfan-in`.
                - ``'out'``: Use `Hyperfan-out`.
                - ``'harmonic'``: Use the harmonic mean of the `Hyperfan-in`
                  and `Hyperfan-out` init.
            use_xavier (bool): Whether Kaiming (``False``) or Xavier (``True``)
                init should be used.
            uncond_var (float): The variance of unconditional embeddings. This
                value is only taken into consideration if
                ``uncond_in_size > 0`` (cf. constructor arguments).
            cond_var (float): The initial variance of conditional embeddings.
                This value is only taken into consideration if
                ``cond_in_size > 0`` (cf. constructor arguments).
            mnet (mnets.mnet_interface.MainNetInterface, optional): If
                applicable, the user should provide the main (or target)
                network, whose weights are generated by this hypernetwork. The
                ``mnet`` instance is used to extract valuable information that
                improves the initialization result. If provided, it is assumed
                that ``target_shapes`` (cf. constructor arguments) corresponds
                either to
                :attr:`mnets.mnet_interface.MainNetInterface.param_shapes` or
                :attr:`mnets.mnet_interface.MainNetInterface.hyper_shapes_learned`.
            w_val (list or dict, optional): The mean of the distribution with
                which output head weight matrices are initialized. Note, each
                weight tensor prescribed by
                :attr:`hnets.hnet_interface.HyperNetInterface.target_shapes` is
                produced via an independent linear output head.

                One may either specify a list of numbers having the same
                length as
                :attr:`hnets.hnet_interface.HyperNetInterface.target_shapes`
                or specify a dictionary which may have as keys the tensor
                names occurring in
                :attr:`mnets.mnet_interface.MainNetInterface.param_shapes_meta`
                and the corresponding mean value for the weight matrices of
                all output heads producing this type of tensor.

                If a list is provided, entries may be ``None`` and if a
                dictionary is provided, not all types of parameter tensors
                need to be specified. For tensors, for which no value is
                specified, the default value will be used. The default values
                for tensor types ``'weight'`` and ``'bias'`` are calculated
                based on the proposed hyperfan-initialization. For other
                tensor types the actual hypernet outputs should be drawn from
                the following distributions

                - ``'bn_scale'``: :math:`w \sim \delta(w - 1)`
                - ``'bn_shift'``: :math:`w \sim \delta(w)`
                - ``'cm_scale'``: :math:`w \sim \delta(w - 1)`
                - ``'cm_shift'``: :math:`w \sim \delta(w)`
                - ``'embedding'``: :math:`w \sim \mathcal{N}(0, 1)`

                Which would correspond to the following passed arguments

                .. code-block:: python

                    w_val = {'bn_scale': 0, 'bn_shift': 0, 'cm_scale': 0,
                             'cm_shift': 0, 'embedding': 0}
                    w_var = {'bn_scale': 0, 'bn_shift': 0, 'cm_scale': 0,
                             'cm_shift': 0, 'embedding': 0}
                    b_val = {'bn_scale': 1, 'bn_shift': 0, 'cm_scale': 1,
                             'cm_shift': 0, 'embedding': 0}
                    b_var = {'bn_scale': 0, 'bn_shift': 0, 'cm_scale': 0,
                             'cm_shift': 0, 'embedding': 1}
            w_var (list or dict, optional): The variance of the distribution
                with which output head weight matrices are initialized. A
                variance value of zero means that weights are set to a
                constant defined by ``w_val``. See description of argument
                ``w_val`` for more details.
            b_val (list or dict, optional): The mean of the distribution with
                which output head bias vectors are initialized. See
                description of argument ``w_val`` for more details.
            b_var (list or dict, optional): The variance of the distribution
                with which output head bias vectors are initialized. See
                description of argument ``w_val`` for more details.
        """
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value "%s" for argument "method".' %
                             method)
        if self.unconditional_params is None:
            assert self._no_uncond_weights
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')

        ### Extract meta-information about target shapes ###
        meta = None
        if mnet is not None:
            assert isinstance(mnet, MainNetInterface)

            try:
                meta = mnet.param_shapes_meta
            except:
                meta = None

            if meta is not None:
                if len(self.target_shapes) == len(mnet.param_shapes):
                    pass # meta = mnet.param_shapes_meta
                elif len(self.target_shapes) == len(mnet.hyper_shapes_learned):
                    meta = []
                    for ii in mnet.hyper_shapes_learned_ref:
                        meta.append(mnet.param_shapes_meta[ii])
                else:
                    warn('Target shapes of this hypernetwork could not be ' +
                         'matched to the meta information provided to the ' +
                         'initialization.')
                    meta = None

        # TODO If the user doesn't (or can't) provide an `mnet` instance, we
        # should alternatively allow him to pass meta information directly.
        if meta is None:
            meta = []

            # Heuristic approach to derive meta information from given shapes.
            layer_ind = 0
            for i, s in enumerate(self.target_shapes):
                curr_meta = dict()

                if len(s) > 1:
                    curr_meta['name'] = 'weight'
                    curr_meta['layer'] = layer_ind
                    layer_ind += 1
                else: # just a heuristic, we can't know
                    curr_meta['name'] = 'bias'
                    if i > 0 and meta[-1]['name'] == 'weight':
                        curr_meta['layer'] = meta[-1]['layer']
                    else:
                        curr_meta['layer'] = -1

                meta.append(curr_meta)

        assert len(meta) == len(self.target_shapes)

        # Mapping from layer index to the corresponding shape.
        layer_shapes = dict()
        # Mapping from layer index to whether the layer has a bias vector.
        layer_has_bias = defaultdict(lambda: False)
        for i, m in enumerate(meta):
            if m['name'] == 'weight' and m['layer'] != -1:
                assert len(self.target_shapes[i]) > 1
                layer_shapes[m['layer']] = self.target_shapes[i]
            if m['name'] == 'bias' and m['layer'] != -1:
                layer_has_bias[m['layer']] = True

        ### Compute input variance ###
        cond_dim = self._cond_in_size
        uncond_dim = self._uncond_in_size
        inp_dim = cond_dim + uncond_dim

        input_variance = 0
        if cond_dim > 0:
            input_variance += (cond_dim / inp_dim) * cond_var
        if uncond_dim > 0:
            input_variance += (uncond_dim / inp_dim) * uncond_var

        ### Initialize hidden layers to preserve variance ###
        # Note, if batchnorm layers are used, they will simply be initialized
        # to have no effect after initialization. This does not affect the
        # performed whitening operation.
        if self.batchnorm_layers is not None:
            for bn_layer in self.batchnorm_layers:
                if hasattr(bn_layer, 'scale'):
                    nn.init.ones_(bn_layer.scale)
                if hasattr(bn_layer, 'bias'):
                    nn.init.zeros_(bn_layer.bias)

            # Since batchnorm layers whiten the statistics of hidden
            # activities, the variance of the input will not be preserved by
            # Xavier/Kaiming.
            if len(self.batchnorm_layers) > 0:
                input_variance = 1.

        # We initialize biases with 0 (see Xavier assumption 4 in the Hyperfan
        # paper). Otherwise, we couldn't ignore the biases when computing the
        # output variance of a layer.
        # Note, we have to use fan-in init for the hidden layer to ensure the
        # property, that we preserve the input variance.
        assert len(self._layers) + 1 == len(self.layer_weight_tensors)
        for i, w_tensor in enumerate(self.layer_weight_tensors[:-1]):
            if use_xavier:
                iutils.xavier_fan_in_(w_tensor)
            else:
                torch.nn.init.kaiming_uniform_(w_tensor, mode='fan_in',
                                               nonlinearity='relu')

            if self.has_bias:
                nn.init.zeros_(self.layer_bias_vectors[i])

        ### Define default parameters of weight init distributions ###
        w_val_list = []
        w_var_list = []
        b_val_list = []
        b_var_list = []

        for i, m in enumerate(meta):
            def extract_val(user_arg):
                curr = None
                if isinstance(user_arg, (list, tuple)) and \
                        user_arg[i] is not None:
                    curr = user_arg[i]
                elif isinstance(user_arg, (dict)) and \
                        m['name'] in user_arg.keys():
                    curr = user_arg[m['name']]
                return curr

            curr_w_val = extract_val(w_val)
            curr_w_var = extract_val(w_var)
            curr_b_val = extract_val(b_val)
            curr_b_var = extract_val(b_var)

            if m['name'] == 'weight' or m['name'] == 'bias':
                if None in [curr_w_val, curr_w_var, curr_b_val, curr_b_var]:
                    # If distribution not fully specified, then we just fall
                    # back to hyper-fan init.
                    curr_w_val = None
                    curr_w_var = None
                    curr_b_val = None
                    curr_b_var = None
            else:
                assert m['name'] in ['bn_scale', 'bn_shift', 'cm_scale',
                                     'cm_shift', 'embedding']
                if curr_w_val is None:
                    curr_w_val = 0
                if curr_w_var is None:
                    curr_w_var = 0
                if curr_b_val is None:
                    curr_b_val = 1 if m['name'] in ['bn_scale', 'cm_scale'] \
                        else 0
                if curr_b_var is None:
                    curr_b_var = 1 if m['name'] in ['embedding'] else 0

            w_val_list.append(curr_w_val)
            w_var_list.append(curr_w_var)
            b_val_list.append(curr_b_val)
            b_var_list.append(curr_b_var)

        ### Initialize output heads ###
        # Note, that all output heads are realized internally via one large
        # fully-connected layer.
        # All output heads are linear layers. The biases of these linear
        # layers (called gamma and beta in the paper) are simply initialized
        # to zero. Note, that we allow deviations from this below.
        if self.has_bias:
            nn.init.zeros_(self.layer_bias_vectors[-1])

        c_relu = 1 if use_xavier else 2

        # We are not interested in the fan-out, since the fan-out is just
        # the number of elements in the main network.
        # `fan-in` is called `d_k` in the paper and is just the size of the
        # last hidden layer.
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out( \
            self.layer_weight_tensors[-1])

        s_ind = 0
        for i, out_shape in enumerate(self.target_shapes):
            m = meta[i]
            e_ind = s_ind + int(np.prod(out_shape))

            curr_w_val = w_val_list[i]
            curr_w_var = w_var_list[i]
            curr_b_val = b_val_list[i]
            curr_b_var = b_var_list[i]

            if curr_w_val is None:
                c_bias = 2 if layer_has_bias[m['layer']] else 1

                if m['name'] == 'bias':
                    m_fan_out = out_shape[0]

                    # NOTE For the hyperfan-out init, we also need to know the
                    # fan-in of the layer.
                    if m['layer'] != -1:
                        m_fan_in, _ = iutils.calc_fan_in_and_out( \
                            layer_shapes[m['layer']])
                    else:
                        # FIXME Quick-fix.
                        m_fan_in = m_fan_out

                    var_in = c_relu / (2. * fan_in * input_variance)
                    num = c_relu * (1. - m_fan_in/m_fan_out)
                    denom = fan_in * input_variance
                    var_out = max(0, num / denom)

                else:
                    assert m['name'] == 'weight'

                    m_fan_in, m_fan_out = iutils.calc_fan_in_and_out(out_shape)

                    var_in = c_relu / (c_bias * m_fan_in * fan_in * \
                                       input_variance)
                    var_out = c_relu / (m_fan_out * fan_in * input_variance)

                if method == 'in':
                    var = var_in
                elif method == 'out':
                    var = var_out
                elif method == 'harmonic':
                    var = 2 * (1./var_in + 1./var_out)
                else:
                    raise ValueError('Method %s invalid.' % method)

                # Initialize output head weight tensor using `var`.
                std = math.sqrt(var)
                a = math.sqrt(3.0) * std
                torch.nn.init._no_grad_uniform_( \
                    self.layer_weight_tensors[-1][s_ind:e_ind, :], -a, a)
            else:
                if curr_w_var == 0:
                    nn.init.constant_(
                        self.layer_weight_tensors[-1][s_ind:e_ind, :],
                        curr_w_val)
                else:
                    std = math.sqrt(curr_w_var)
                    a = math.sqrt(3.0) * std
                    torch.nn.init._no_grad_uniform_( \
                        self.layer_weight_tensors[-1][s_ind:e_ind, :],
                        curr_w_val-a, curr_w_val+a)

                if curr_b_var == 0:
                    nn.init.constant_(
                        self.layer_bias_vectors[-1][s_ind:e_ind],
                        curr_b_val)
                else:
                    std = math.sqrt(curr_b_var)
                    a = math.sqrt(3.0) * std
                    torch.nn.init._no_grad_uniform_( \
                        self.layer_bias_vectors[-1][s_ind:e_ind],
                        curr_b_val-a, curr_b_val+a)

            s_ind = e_ind
    def get_cond_in_emb(self, cond_id):
        """Get the ``cond_id``-th (conditional) input embedding.

        Args:
            cond_id (int): Determines which input embedding should be returned
                (the ID has to be between ``0`` and ``num_cond_embs-1``, where
                ``num_cond_embs`` denotes the corresponding constructor
                argument).

        Returns:
            (torch.nn.Parameter)
        """
        if self.conditional_params is None:
            raise RuntimeError('Input embeddings are not internally ' +
                               'maintained!')
        if not isinstance(cond_id, int) or cond_id < 0 or \
                cond_id >= len(self.conditional_params):
            raise RuntimeError('Option "cond_id" must be between 0 and %d!' \
                % (len(self.conditional_params)-1))
        return self.conditional_params[cond_id]
if __name__ == '__main__':
    pass
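    # Illustrative usage sketch (not part of the original module): build a
    # hypernet for some made-up target shapes, apply hyperfan initialization
    # and generate one set of target weights. All sizes are example values.
    demo_shapes = [[20, 10], [20], [5, 20], [5]]
    demo_hnet = HMLP(demo_shapes, uncond_in_size=0, cond_in_size=8,
                     layers=(100, 100), num_cond_embs=2)
    demo_hnet.apply_hyperfan_init(method='in')
    demo_weights = demo_hnet.forward(cond_id=0)
    print([w.shape for w in demo_weights])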