
#!/usr/bin/env python3
# Copyright 2020 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title          :hnets/chunked_mlp_hnet.py
# @author         :ch
# @contact        :henningc@ethz.ch
# @created        :04/15/2020
# @version        :1.0
# @python_version :3.6.10
r"""
Chunked MLP - Hypernetwork
--------------------------

The module :mod:`hnets.chunked_mlp_hnet` contains a `Chunked Hypernetwork` that
uses a full hypernetwork (see :class:`hnets.mlp_hnet.HMLP`) to produce one
chunk of the output weights at a time.

The hypernetwork :math:`h_\theta(e)` (with input :math:`e`) operates as follows.
The target outputs (see
:attr:`hnets.hnet_interface.HyperNetInterface.target_shapes`) are flattened and
split into equally sized chunks. Those chunks are separately generated by an
internal full hypernetwork :math:`h'_{\theta'}(e,c)` (that is hidden from the
user), where :math:`c` denotes a chunk embedding. Chunk embeddings are
internally maintained and chunk-specific.

Note:
    This type of hypernetwork is completely agnostic to the architecture of the
    target network. The splits happen at arbitrary locations in the flattened
    target network weight vector.
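
Example:
    A minimal usage sketch (the target shapes and hyperparameters below are
    purely illustrative):

    .. code-block:: python

        from hypnettorch.hnets.chunked_mlp_hnet import ChunkedHMLP

        # Weights of a hypothetical small target network (82 weights total).
        target_shapes = [[10, 5], [10], [2, 10], [2]]

        # With chunk_size=30, the flattened output is produced in
        # ceil(82 / 30) = 3 chunks; the surplus of the last chunk is
        # discarded.
        hnet = ChunkedHMLP(target_shapes, chunk_size=30, chunk_emb_size=8,
                           cond_in_size=8, layers=(16, 16), num_cond_embs=5)

        # Generate the weights associated with condition 0; the result is a
        # list of tensors with shapes matching ``target_shapes``.
        weights = hnet.forward(cond_id=0)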
"""
from collections import defaultdict
import math
import numpy as np
import torch
import torch.nn as nn
from warnings import warn

from hypnettorch.hnets.hnet_interface import HyperNetInterface
from hypnettorch.hnets.mlp_hnet import HMLP
from hypnettorch.mnets.mnet_interface import MainNetInterface
from hypnettorch.utils import init_utils as iutils

class ChunkedHMLP(nn.Module, HyperNetInterface):
    """Implementation of a `chunked fully-connected hypernet`.

    The ``target_shapes`` will be flattened and split into chunks of size
    ``chunk_size``. In total, there will be
    ``np.ceil(self.num_outputs/chunk_size)`` chunks, where the last chunk
    produced might contain a remainder that is discarded.

    Each chunk has its own `chunk embedding` that is fed into the underlying
    hypernetwork.

    Note:
        It is possible to set ``uncond_in_size`` and ``cond_in_size`` to zero
        if ``cond_chunk_embs`` is ``True``.

    Args:
        (....): See constructor arguments of class :class:`hnets.mlp_hnet.HMLP`.
        chunk_size (int): The chunk size, i.e., the number of weights produced
            by individual forward passes of the internally maintained instance
            of a full hypernet (see :class:`hnets.mlp_hnet.HMLP`) upon
            receiving a chunk embedding.
        chunk_emb_size (int): The size of a chunk embedding.

            Whether chunk embeddings are unconditional (``False``) or
            conditional (``True``) parameters is determined by the constructor
            argument ``cond_chunk_embs``.

            Note:
                Embeddings will be initialized with a normal distribution using
                zero mean and unit variance.
        cond_chunk_embs (bool): Consider chunk embeddings to be conditional.
            In this case, there will be a different set of chunk embeddings
            per condition (specified via ``num_cond_embs``).

            If ``False``, there will be a total of :attr:`num_chunks` chunk
            embeddings that are maintained within
            :attr:`hnets.hnet_interface.HyperNetInterface.unconditional_param_shapes`.
            If ``True``, there will be ``num_cond_embs * self.num_chunks``
            chunk embeddings that are maintained within
            :attr:`hnets.hnet_interface.HyperNetInterface.conditional_param_shapes`.
            However, if ``num_cond_embs == 0``, then chunk embeddings have to
            be provided in a special way to the :meth:`forward` method (see
            the corresponding argument ``weights``).
    """
    def __init__(self, target_shapes, chunk_size, chunk_emb_size=8,
                 cond_chunk_embs=False, uncond_in_size=0, cond_in_size=8,
                 layers=(100, 100), verbose=True,
                 activation_fn=torch.nn.ReLU(), use_bias=True,
                 no_uncond_weights=False, no_cond_weights=False,
                 num_cond_embs=1, dropout_rate=-1, use_spectral_norm=False,
                 use_batch_norm=False):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        HyperNetInterface.__init__(self)

        assert isinstance(chunk_size, int) and chunk_size > 0
        assert isinstance(chunk_emb_size, int) and chunk_emb_size > 0

        ### Make constructor arguments internally available ###
        self._chunk_size = chunk_size
        self._chunk_emb_size = chunk_emb_size
        self._cond_chunk_embs = cond_chunk_embs
        self._uncond_in_size = uncond_in_size
        self._cond_in_size = cond_in_size
        self._no_uncond_weights = no_uncond_weights
        self._no_cond_weights = no_cond_weights
        self._num_cond_embs = num_cond_embs

        ### Create underlying full hypernet ###
        # Note, even if chunk embeddings are considered conditional, they are
        # maintained in this object and just fed as an external input to the
        # underlying hnet.
        hnet_uncond_in_size = uncond_in_size + chunk_emb_size
        hnet_num_cond_embs = num_cond_embs

        if cond_chunk_embs and num_cond_embs == 0:
            raise ValueError('Conditional chunk embeddings can only be used ' +
                             'if conditions are known to the hypernetwork!')

        if cond_chunk_embs and cond_in_size == 0:
            # If there are no other conditional embeddings except the chunk
            # embeddings, we tell the underlying hnet explicitly that it
            # doesn't need to maintain any conditional weights to avoid that
            # it will throw a warning.
            hnet_num_cond_embs = 0

        self._hnet = HMLP([[chunk_size]], uncond_in_size=hnet_uncond_in_size,
            cond_in_size=cond_in_size, layers=layers, verbose=False,
            activation_fn=activation_fn, use_bias=use_bias,
            no_uncond_weights=no_uncond_weights,
            no_cond_weights=no_cond_weights, num_cond_embs=hnet_num_cond_embs,
            dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm)

        ### Setup attributes required by interface ###
        # Most of these attributes are taken over from `self._hnet`.
        self._target_shapes = target_shapes
        self._num_known_conds = self._num_cond_embs
        self._unconditional_param_shapes_ref = \
            list(self._hnet._unconditional_param_shapes_ref)

        if self._hnet._internal_params is not None:
            self._internal_params = \
                nn.ParameterList(self._hnet._internal_params)
        self._param_shapes = list(self._hnet._param_shapes)
        self._param_shapes_meta = list(self._hnet._param_shapes_meta)
        if self._hnet._hyper_shapes_learned is not None:
            self._hyper_shapes_learned = \
                list(self._hnet._hyper_shapes_learned)
            self._hyper_shapes_learned_ref = \
                list(self._hnet._hyper_shapes_learned_ref)
        if self._hnet._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = \
                list(self._hnet._hyper_shapes_distilled)
        self._has_bias = self._hnet._has_bias
        self._has_fc_out = self._hnet._has_fc_out
        # Just to make that clear explicitly. We will additionally append the
        # chunk embeddings at the end of `param_shapes`. We don't prepend them
        # to the beginning, to keep conditional input embeddings at the
        # beginning.
        self._mask_fc_out = False
        self._has_linear_out = self._hnet._has_linear_out
        self._layer_weight_tensors = \
            nn.ParameterList(self._hnet._layer_weight_tensors)
        self._layer_bias_vectors = \
            nn.ParameterList(self._hnet._layer_bias_vectors)
        if self._hnet._batchnorm_layers is not None:
            self._batchnorm_layers = \
                nn.ModuleList(self._hnet._batchnorm_layers)
        if self._hnet._context_mod_layers is not None:
            self._context_mod_layers = \
                nn.ModuleList(self._hnet._context_mod_layers)

        ### Create chunk embeddings ###
        if cond_in_size == 0 and uncond_in_size == 0 and not cond_chunk_embs:
            # Note, we could also allow this case. It would be analogous to
            # creating a full hypernet with no unconditional input and one
            # conditional embedding. But the user can explicitly achieve that
            # as noted below.
            raise ValueError('If no external (conditional or unconditional) ' +
                             'input is provided to the hypernetwork, then ' +
                             'it can only learn a fixed output. If this ' +
                             'behavior is desired, please enable ' +
                             '"cond_chunk_embs" and set "num_cond_embs=1".')

        num_cemb_mats = 1
        no_cemb_weights = no_uncond_weights
        if cond_chunk_embs:
            num_cemb_mats = num_cond_embs
            no_cemb_weights = no_cond_weights

        self._cemb_shape = [self.num_chunks, chunk_emb_size]

        for _ in range(num_cemb_mats):
            if not no_cemb_weights:
                self._internal_params.append(nn.Parameter(
                    data=torch.Tensor(*self._cemb_shape), requires_grad=True))
                torch.nn.init.normal_(self._internal_params[-1], mean=0.,
                                      std=1.)
            else:
                self._hyper_shapes_learned.append(self._cemb_shape)
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            if not cond_chunk_embs:
                self._unconditional_param_shapes_ref.append(
                    len(self.param_shapes))

            self._param_shapes.append(self._cemb_shape)
            # In principle, these embeddings also belong to the input, so we
            # just assign them as "layer" 0 (note, the underlying hnet uses
            # the same layer ID for its embeddings).
            self._param_shapes_meta.append({
                'name': 'embedding',
                'index': -1 if no_cemb_weights else \
                    len(self._internal_params)-1,
                'layer': 0,
                'info': 'chunk embeddings'
            })

        ### Finalize construction ###
        self._is_properly_setup()

        if verbose:
            print('Created Chunked MLP Hypernet with %d chunk(s) of size %d.'
                  % (self.num_chunks, chunk_size))
            print(self)

    @property
    def num_chunks(self):
        """The number of chunks that make up the final hypernet output.

        This also corresponds to the number of chunk embeddings required per
        forward sweep.

        :type: int
        """
        return int(np.ceil(self.num_outputs / self._chunk_size))

    @property
    def chunk_emb_size(self):
        """See constructor argument ``chunk_emb_size``."""
        return self._chunk_emb_size

    @property
    def cond_chunk_embs(self):
        """See constructor argument ``cond_chunk_embs``."""
        return self._cond_chunk_embs

    def forward(self, uncond_input=None, cond_input=None, cond_id=None,
                weights=None, distilled_params=None, condition=None,
                ret_format='squeezed'):
        """Compute the weights of a target network.

        Args:
            (....): See docstring of method
                :meth:`hnets.mlp_hnet.HMLP.forward`.
            weights (list or dict, optional): If provided as ``dict`` and
                chunk embeddings are considered conditional (see constructor
                argument ``cond_chunk_embs``), then the additional key
                ``chunk_embs`` can be used to pass a batch of chunk
                embeddings. This option is mutually exclusive with the option
                of passing ``cond_id``. Note, if conditional inputs via
                ``cond_input`` are expected, then the batch sizes must agree.

                A batch of chunk embeddings is expected to be a tensor of
                shape ``[B, num_chunks, chunk_emb_size]``, where ``B`` denotes
                the batch size.

        Returns:
            (list or torch.Tensor): See docstring of method
            :meth:`hnets.hnet_interface.HyperNetInterface.forward`.
        """
        cond_chunk_embs = None
        if isinstance(weights, dict):
            if 'chunk_embs' in weights.keys():
                cond_chunk_embs = weights['chunk_embs']
                if not self._cond_chunk_embs:
                    raise ValueError('Key "chunk_embs" for argument ' +
                                     '"weights" is only allowed if chunk ' +
                                     'embeddings are conditional.')
                assert len(cond_chunk_embs.shape) == 3 and \
                    np.all(np.equal(cond_chunk_embs.shape[1:],
                                    [self.num_chunks, self.chunk_emb_size]))
                if cond_id is not None:
                    raise ValueError('Option "cond_id" is mutually exclusive ' +
                                     'with key "chunk_embs" for argument ' +
                                     '"weights".')
                assert cond_input is None or \
                    cond_input.shape[0] == cond_chunk_embs.shape[0]

                # Remove `chunk_embs` from the dictionary, since the parser of
                # the parent class doesn't know how to deal with it.
                del weights['chunk_embs']
                if len(weights.keys()) == 0: # Empty dictionary.
                    weights = None

        if cond_input is not None and self._cond_chunk_embs and \
                cond_chunk_embs is None:
            raise ValueError('Conditional chunk embeddings have to be ' +
                             'provided via "weights" if "cond_input" is ' +
                             'specified.')

        _input_required = self._cond_in_size > 0 or self._uncond_in_size > 0
        # We parse `cond_id` afterwards if chunk embeddings are also
        # conditional.
        if self._cond_chunk_embs:
            _parse_cond_id_fct = lambda x, y, z: None
        else:
            _parse_cond_id_fct = None

        uncond_input, cond_input, uncond_weights, cond_weights = \
            self._preprocess_forward_args(_input_required=_input_required,
                _parse_cond_id_fct=_parse_cond_id_fct,
                uncond_input=uncond_input, cond_input=cond_input,
                cond_id=cond_id, weights=weights,
                distilled_params=distilled_params, condition=condition,
                ret_format=ret_format)

        ### Translate IDs to conditional inputs ###
        if cond_id is not None and self._cond_chunk_embs:
            assert cond_input is None and cond_chunk_embs is None
            cond_id = [cond_id] if isinstance(cond_id, int) else cond_id

            if cond_weights is None:
                raise ValueError('Forward option "cond_id" can only be ' +
                                 'used if conditional parameters are ' +
                                 'maintained internally or passed to the ' +
                                 'forward method via option "weights".')

            cond_chunk_embs = []
            cond_input = [] if self._cond_in_size > 0 else None

            for i, cid in enumerate(cond_id):
                if cid < 0 or cid >= self._num_cond_embs:
                    raise ValueError('Condition %d not existing!' % (cid))

                # Note, we do not necessarily have conditional embeddings.
                if self._cond_in_size > 0:
                    cond_input.append(cond_weights[cid])
                cond_chunk_embs.append(
                    cond_weights[-self._num_cond_embs+cid])

            if self._cond_in_size > 0:
                cond_input = torch.stack(cond_input, dim=0)
            cond_chunk_embs = torch.stack(cond_chunk_embs, dim=0)

        ### Assemble hypernetwork input ###
        batch_size = None
        if cond_input is not None:
            batch_size = cond_input.shape[0]
        if cond_chunk_embs is not None:
            assert batch_size is None or \
                batch_size == cond_chunk_embs.shape[0]
            batch_size = cond_chunk_embs.shape[0]
        if uncond_input is not None:
            if batch_size is None:
                batch_size = uncond_input.shape[0]
            else:
                assert batch_size == uncond_input.shape[0]
        assert batch_size is not None

        chunk_embs = None
        if self._cond_chunk_embs:
            assert cond_chunk_embs is not None and \
                len(cond_chunk_embs.shape) == 3
            assert self._cond_in_size == 0 or cond_input is not None
            chunk_embs = cond_chunk_embs
        else:
            assert cond_chunk_embs is None
            chunk_embs = uncond_weights[-1]
            # Insert batch dimension.
            chunk_embs = chunk_embs.expand(batch_size, self.num_chunks,
                                           self.chunk_emb_size)

        # We now have the following setup:
        # cond_input: [batch_size, cond_in_size] or None
        # uncond_input: [batch_size, uncond_in_size] or None
        # chunk_embs: [batch_size, num_chunks, chunk_emb_size]

        # We now first copy the hypernet inputs for each chunk, arriving at
        # cond_input: [batch_size, num_chunks, cond_in_size] or None
        # uncond_input: [batch_size, num_chunks, uncond_in_size] or None
        if cond_input is not None:
            cond_input = cond_input.reshape(batch_size, 1, -1)
            cond_input = cond_input.expand(batch_size, self.num_chunks,
                                           self._cond_in_size)
        if uncond_input is not None:
            uncond_input = uncond_input.reshape(batch_size, 1, -1)
            uncond_input = uncond_input.expand(batch_size, self.num_chunks,
                                               self._uncond_in_size)
            # The chunk embeddings are considered unconditional inputs to the
            # underlying hypernetwork.
            uncond_input = torch.cat([uncond_input, chunk_embs], dim=2)
        else:
            uncond_input = chunk_embs

        # Now we build one big batch for the underlying hypernetwork, with
        # batch size: batch_size * num_chunks.
        if cond_input is not None:
            cond_input = cond_input.reshape(batch_size * self.num_chunks, -1)
        uncond_input = uncond_input.reshape(batch_size * self.num_chunks, -1)

        ### Weights of underlying hypernetwork ###
        weights = dict()
        if cond_weights is not None and self._cond_chunk_embs:
            weights['cond_weights'] = cond_weights[:-self._num_cond_embs]
        elif cond_weights is not None:
            weights['cond_weights'] = cond_weights

        assert uncond_weights is not None
        if self._cond_chunk_embs:
            weights['uncond_weights'] = uncond_weights
        else:
            weights['uncond_weights'] = uncond_weights[:-1]

        ### Process chunks ###
        hnet_out = self._hnet.forward(uncond_input=uncond_input,
            cond_input=cond_input, cond_id=None, weights=weights,
            distilled_params=distilled_params, condition=condition,
            ret_format='flattened')
        assert np.all(np.equal(hnet_out.shape,
                               [batch_size * self.num_chunks,
                                self._chunk_size]))

        # FIXME We can skip this line, right?
        hnet_out = hnet_out.view(batch_size, self.num_chunks,
                                 self._chunk_size)
        # Concatenate individual chunks.
        hnet_out = hnet_out.view(batch_size,
                                 self.num_chunks * self._chunk_size)

        # Throw away unused part of last chunk.
        hnet_out = hnet_out[:, :self.num_outputs]

        ### Assemble hypernet output ###
        ret = self._flat_to_ret_format(hnet_out, ret_format)

        return ret

    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        Returns:
            See :meth:`hnets.mlp_hnet.HMLP.distillation_targets`.
        """
        # We don't have any additional distillation targets. We also just pass
        # `distilled_params` to the underlying hypernetwork in the `forward`
        # method.
        return self._hnet.distillation_targets()

    def apply_chunked_hyperfan_init(self, method='in', use_xavier=False,
                                    uncond_var=1., cond_var=1., eps=1e-5,
                                    cemb_normal_init=False, mnet=None,
                                    target_vars=None):
        r"""Initialize the network using a chunked hyperfan init.

        Inspired by the method
        `Hyperfan Init <https://openreview.net/forum?id=H1lma24tPB>`__, which
        we implemented for the MLP hypernetwork in method
        :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`, we heuristically
        developed a better initialization method for chunked hypernetworks.

        Unfortunately, the `Hyperfan Init` method from the paper does not
        apply to this kind of hypernetwork, since we reuse the same hypernet
        output head for the whole main network.

        Luckily, we can provide a simple heuristic. Similar to
        `Meyerson & Miikkulainen <https://arxiv.org/abs/1906.00097>`__ we play
        with the variance of the input embeddings to affect the variance of
        the output weights.

        In a chunked hypernetwork, the input for each chunk is identical
        except for the chunk embeddings :math:`\mathbf{c}`. Let
        :math:`\mathbf{e}` denote the remaining inputs to the hypernetwork,
        which are identical for all chunks.

        Then, assuming the hypernetwork was initialized via fan-in init, the
        variance of the hypernetwork output :math:`\mathbf{v}` can be written
        as follows (see documentation of method
        :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`):

        .. math::

            \text{Var}(v) = \frac{n_e}{n_e+n_c} \text{Var}(e) +
                \frac{n_c}{n_e+n_c} \text{Var}(c)

        Hence, we can achieve a desired output variance :math:`\text{Var}(v)`
        by initializing the chunk embeddings :math:`\mathbf{c}` via the
        following variance:

        .. math::

            \text{Var}(c) = \max \Big\{ 0,
                \frac{1}{n_c} \big[ (n_e+n_c) \text{Var}(v) -
                n_e \text{Var}(e) \big] \Big\}

        Now, one important question remains. How do we pick a desired output
        variance :math:`\text{Var}(v)` for a chunk?

        Note, a chunk may include weights from several layers. The likelihood
        for this to happen depends on the main net architecture and the chunk
        size (see constructor argument ``chunk_size``). The smaller the chunk
        size, the less likely it is that a chunk will contain elements from
        multiple main net weight tensors.

        In case each chunk would contain only weights from one main net weight
        tensor, we could simply pick the variance :math:`\text{Var}(v)` that
        would have been chosen by a main net initialization method (such as
        Xavier).

        In case a chunk contains contributions from several main net weight
        tensors, we apply the following heuristic. If a chunk contains
        contributions of a set of main network weight tensors
        :math:`W_1, \dots, W_K` with relative contribution sizes
        :math:`n_1, \dots, n_K` such that :math:`n_1 + \dots + n_K = n_v`,
        where :math:`n_v` denotes the chunk size, and if the corresponding
        main network initialization method would require init variances
        :math:`\text{Var}(w_1), \dots, \text{Var}(w_K)`, then we simply
        request a weighted average as follows:

        .. math::

            \text{Var}(v) = \frac{1}{n_v} \sum_{k=1}^K n_k \text{Var}(w_k)

        What about bias vectors? Usually, the variance analysis applied to
        Xavier or Kaiming init assumes that biases are initialized to zero.
        This is not possible in this setting, as it would require assigning a
        negative variance to :math:`\mathbf{c}`. Instead, we follow the
        default PyTorch initialization (e.g., see method ``reset_parameters``
        in class :class:`torch.nn.Linear`). There, bias vectors are
        initialized uniformly within a range of
        :math:`\pm \frac{1}{\sqrt{f_{\text{in}}}}`, where
        :math:`f_{\text{in}}` refers to the fan-in of the layer. This type of
        initialization corresponds to a variance of
        :math:`\text{Var}(v) = \frac{1}{3 f_{\text{in}}}`.

        Note:
            All hypernet inputs are assumed to be zero-mean random variables.

        Note:
            To avoid that the variances with which chunks are initialized have
            to be clipped (because they are too small or even negative), the
            variance of the remaining hypernet inputs should be properly
            scaled. In general, one should adhere to the following rule:

            .. math::

                \text{Var}(e) < \frac{n_e+n_c}{n_e} \text{Var}(v)

            This method will calculate and print the maximum value that should
            be chosen for :math:`\text{Var}(e)` and will print warnings if
            variances have to be clipped.

        Args:
            (....): See arguments of method
                :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`.
            method (str): The type of initialization that should be applied.
                Possible options are:

                - ``in``: Use `Chunked Hyperfan-in`, i.e., the output
                  variances of the hypernetwork should correspond to fan-in
                  variances.
                - ``out``: Use `Chunked Hyperfan-out`, i.e., the output
                  variances of the hypernetwork should correspond to fan-out
                  variances.
                - ``harmonic``: Use the harmonic mean of the fan-in and
                  fan-out variance as target variance of the hypernetwork
                  output.
            eps (float): The minimum variance with which a chunk embedding is
                initialized.
            cemb_normal_init (bool): Use normal init for chunk embeddings
                rather than uniform init.
            target_vars (list or dict, optional): The variance of the
                distribution for each parameter tensor generated by this
                hypernetwork.

                Target variance values can either be provided as list of
                length ``len(hnet.target_shapes)`` or as dictionary. The usage
                is analogous to the usage of parameter ``w_val`` of method
                :meth:`hnets.mlp_hnet.HMLP.apply_hyperfan_init`.

                Note:
                    This method currently does not allow initial output
                    distributions with non-zero mean. However, the docstring
                    of method
                    :meth:`probabilistic.gauss_hnet_init.gauss_hyperfan_init`
                    describes how this is in principle feasible and might be
                    incorporated in the future.

                Note:
                    Unspecified target variances for parameter tensors of type
                    ``'weight'`` or ``'bias'`` are computed as described
                    above. Default target variances for all other parameter
                    tensor types are simply ``1``.
        """
        if method not in ['in', 'out', 'harmonic']:
            raise ValueError('Invalid value "%s" for argument "method".' %
                             method)
        if self.unconditional_params is None:
            assert self._no_uncond_weights
            raise ValueError('Hypernet without internal weights can\'t be ' +
                             'initialized.')
        if self.conditional_params is None and self._cond_chunk_embs:
            assert self._no_cond_weights
            raise ValueError('Chunked hyperfan init cannot be applied if ' +
                             'chunk embeddings are not internally maintained.')

        ### Extract meta-information about target shapes ###
        # FIXME This section is copied from the HMLP implementation.
        meta = None
        if mnet is not None:
            assert isinstance(mnet, MainNetInterface)

            try:
                meta = mnet.param_shapes_meta
            except:
                meta = None

            if meta is not None:
                if len(self.target_shapes) == len(mnet.param_shapes):
                    pass
                    # meta = mnet.param_shapes_meta
                elif len(self.target_shapes) == len(mnet.hyper_shapes_learned):
                    meta = []
                    for ii in mnet.hyper_shapes_learned_ref:
                        meta.append(mnet.param_shapes_meta[ii])
                else:
                    warn('Target shapes of this hypernetwork could not be ' +
                         'matched to the meta information provided to the ' +
                         'initialization.')
                    meta = None

        # TODO If the user doesn't (or can't) provide an `mnet` instance, we
        # should alternatively allow them to pass meta information directly.
        if meta is None:
            meta = []
            # Heuristic approach to derive meta information from given shapes.
            layer_ind = 0
            for i, s in enumerate(self.target_shapes):
                curr_meta = dict()

                if len(s) > 1:
                    curr_meta['name'] = 'weight'
                    curr_meta['layer'] = layer_ind
                    layer_ind += 1
                else: # just a heuristic, we can't know
                    curr_meta['name'] = 'bias'
                    if i > 0 and meta[-1]['name'] == 'weight':
                        curr_meta['layer'] = meta[-1]['layer']
                    else:
                        curr_meta['layer'] = -1

                meta.append(curr_meta)

        assert len(meta) == len(self.target_shapes)

        # Mapping from layer index to the corresponding shape.
        layer_shapes = dict()
        # Mapping from layer index to whether the layer has a bias vector.
        layer_has_bias = defaultdict(lambda: False)
        for i, m in enumerate(meta):
            if m['name'] == 'weight' and m['layer'] != -1:
                assert len(self.target_shapes[i]) > 1
                layer_shapes[m['layer']] = self.target_shapes[i]
            if m['name'] == 'bias' and m['layer'] != -1:
                layer_has_bias[m['layer']] = True

        ### Compute input variance ###
        # The input variance does not include the variance of chunk
        # embeddings! Instead, it is the variance of the inputs that are
        # shared across all chunks.
        cond_dim = self._cond_in_size
        uncond_dim = self._uncond_in_size
        # Note, `inp_dim` can be zero if conditional chunk embeddings are
        # used.
        inp_dim = cond_dim + uncond_dim

        inp_var = 0
        if cond_dim > 0:
            inp_var += (cond_dim / inp_dim) * cond_var
        if uncond_dim > 0:
            inp_var += (uncond_dim / inp_dim) * uncond_var

        c_dim = self.chunk_emb_size

        ### Initialize hypernet with fan-in init ###
        if self.batchnorm_layers is not None and \
                len(self.batchnorm_layers) > 0:
            # Note, batchnorm layers simply whiten the incoming statistics.
            # Thus, if we tune the variance of chunk embeddings, this variance
            # is normalized by a batchnorm layer and thus vanishes.
            raise RuntimeError('Chunked hyperfan init not applicable if a ' +
                               'hypernetwork with batchnorm layers is used.')

        # Note, the whole internal hypernetwork is initialized with fan-in
        # init to simply pass the variance of all inputs to the hypernet
        # output.
        for i, w_tensor in enumerate(self.layer_weight_tensors):
            if use_xavier:
                iutils.xavier_fan_in_(w_tensor)
            else:
                torch.nn.init.kaiming_uniform_(w_tensor, mode='fan_in',
                                               nonlinearity='relu')

            if self.has_bias:
                nn.init.zeros_(self.layer_bias_vectors[i])

        ### Compute target variance of each output tensor ###
        if target_vars is None:
            target_vars = [None] * len(self.target_shapes)
        elif isinstance(target_vars, dict):
            target_vars_d = target_vars
            target_vars = []

            for i, m in enumerate(meta):
                if m['name'] in target_vars_d.keys():
                    target_vars.append(target_vars_d[m['name']])
                else:
                    target_vars.append(None)
        else:
            assert isinstance(target_vars, (list, tuple))
            assert len(target_vars) == len(self.target_shapes)

        for i, s in enumerate(self.target_shapes):
            if target_vars[i] is not None:
                # Use user-specified target variance.
                continue

            m = meta[i]

            if m['name'] == 'bias':
                if m['layer'] != -1:
                    fan_in, _ = iutils.calc_fan_in_and_out(
                        layer_shapes[m['layer']])
                else:
                    # FIXME Quick-fix, use fan-out instead.
                    fan_in = s[0]

                target_vars[i] = 1. / (3. * fan_in)

            elif m['name'] == 'weight':
                fan_in, fan_out = iutils.calc_fan_in_and_out(s)

                c_relu = 1 if use_xavier else 2

                var_in = c_relu / fan_in
                var_out = c_relu / fan_out

                if method == 'in':
                    var = var_in
                elif method == 'out':
                    var = var_out
                else:
                    # Harmonic mean of fan-in and fan-out variance.
                    var = 2. / (1./var_in + 1./var_out)

                target_vars[i] = var

            else:
                target_vars[i] = 1.

        ### Target variance per chunk ###
        chunk_vars = []
        i = 0
        n = np.prod(self.target_shapes[i])
        for j in range(self.num_chunks):
            m = self._chunk_size
            var = 0

            while m > 0:
                # Special treatment to fill up last chunk.
                if j == self.num_chunks-1 and i == len(target_vars)-1:
                    assert n <= m
                    o = m
                else:
                    o = min(m, n)

                var += o / self._chunk_size * target_vars[i]
                m -= o
                n -= o

                if n == 0:
                    i += 1
                    if i < len(target_vars):
                        n = np.prod(self.target_shapes[i])

            chunk_vars.append(var)

        if inp_dim > 0:
            max_inp_var = (inp_dim+c_dim) / inp_dim * min(chunk_vars)
            max_inp_std = math.sqrt(max_inp_var)
            print('Initializing hypernet with Chunked Hyperfan Init ...')
            if inp_var >= max_inp_var:
                warn('Note, hypernetwork inputs should have an initial ' +
                     'total variance (std) smaller than %f (%f) in order ' \
                     % (max_inp_var, max_inp_std) +
                     'for this method to work properly.')

        ### Compute variances of chunk embeddings ###
        # We could have done that in the previous loop. But I think the code
        # is more readable this way.
        c_vars = []
        n_clipped = 0
        for i, var in enumerate(chunk_vars):
            c_var = 1./c_dim * ((inp_dim+c_dim) * var - inp_dim * inp_var)
            if c_var < eps:
                n_clipped += 1
                #warn('Initial variance of chunk embedding %d has to ' % i + \
                #     'be clipped.')

            c_vars.append(max(eps, c_var))

        if n_clipped > 0:
            warn('Initial variance of %d/%d ' % (n_clipped, len(chunk_vars)) +
                 'chunk embeddings had to be clipped.')

        ### Initialize chunk embeddings ###
        for i in range(self.num_chunks):
            c_std = math.sqrt(c_vars[i])

            num_conds = self.num_known_conds if self._cond_chunk_embs else 1
            for j in range(num_conds):
                cond_id = j if self._cond_chunk_embs else None
                c_emb = self.get_chunk_emb(chunk_id=i, cond_id=cond_id)

                if cemb_normal_init:
                    torch.nn.init.normal_(c_emb, mean=0, std=c_std)
                else:
                    a = math.sqrt(3.0) * c_std
                    torch.nn.init._no_grad_uniform_(c_emb, -a, a)

    def get_cond_in_emb(self, cond_id):
        """Get the ``cond_id``-th (conditional) input embedding.

        Args:
            (....): See docstring of method
                :meth:`hnets.mlp_hnet.HMLP.get_cond_in_emb`.

        Returns:
            (torch.nn.Parameter)
        """
        return self._hnet.get_cond_in_emb(cond_id)

    def get_chunk_emb(self, chunk_id=None, cond_id=None):
        """Get the ``chunk_id``-th chunk embedding.

        Args:
            chunk_id (int, optional): A number between 0 and
                :attr:`num_chunks` - 1. If not specified, a full chunk matrix
                with shape ``[num_chunks, chunk_emb_size]`` is returned.
                Otherwise, the ``chunk_id``-th row is returned.
            cond_id (int): Is mandatory if constructor argument
                ``cond_chunk_embs`` was set. Determines the set of chunk
                embeddings to be considered.

        Returns:
            (torch.nn.Parameter)
        """
        if self._cond_chunk_embs:
            if cond_id is None:
                raise RuntimeError('Option "cond_id" has to be set if chunk ' +
                                   'embeddings are conditional parameters!')
            if self.conditional_params is None:
                raise RuntimeError('Conditional chunk embeddings are not ' +
                                   'internally maintained!')
            if not isinstance(cond_id, int) or cond_id < 0 or \
                    cond_id >= self._num_cond_embs:
                raise RuntimeError('Option "cond_id" must be between 0 and ' +
                                   '%d!' % (self._num_cond_embs-1))
            # Note, the last `self._num_cond_embs` params are chunk
            # embeddings.
            chunk_embs = self.conditional_params[-self._num_cond_embs+cond_id]
        else:
            assert cond_id is None
            if self.unconditional_params is None:
                raise RuntimeError('Chunk embeddings are not internally ' +
                                   'maintained!')
            chunk_embs = self.unconditional_params[-1]

        if chunk_id is None:
            return chunk_embs
        else:
            if not isinstance(chunk_id, int) or chunk_id < 0 or \
                    chunk_id >= self.num_chunks:
                raise RuntimeError('Option "chunk_id" must be between 0 and ' +
                                   '%d!' % (self.num_chunks-1))
            return chunk_embs[chunk_id, :]

if __name__ == '__main__':
    pass
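    # A minimal, self-contained demo sketch (shapes and hyperparameters are
    # illustrative only): generate the weights of a small target network from
    # a single unconditional input embedding.
    demo_shapes = [[10, 5], [10], [2, 10], [2]]
    demo_hnet = ChunkedHMLP(demo_shapes, chunk_size=30, chunk_emb_size=8,
                            uncond_in_size=8, cond_in_size=0, num_cond_embs=0,
                            layers=(16, 16))
    demo_weights = demo_hnet.forward(uncond_input=torch.rand(1, 8))
    print([tuple(w.shape) for w in demo_weights])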