# Source code for hypnettorch.hnets.hnet_perturbation_wrapper

#!/usr/bin/env python3
# Copyright 2020 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title          :hnets/hnet_perturbation_wrapper.py
# @author         :ch
# @contact        :henningc@ethz.ch
# @created        :12/17/2020
# @version        :1.0
# @python_version :3.6.10
"""
Hypernetwork-wrapper for input-preprocessing and output-postprocessing
----------------------------------------------------------------------

The module :mod:`hnets.hnet_perturbation_wrapper` implements a wrapper for
hypernetworks that implement the interface
:class:`hnets.hnet_interface.HyperNetInterface`. By default, the wrapper is
meant for perturbing hypernetwork outputs, such that an implicit distribution
(realized via a hypernetwork) with low-dimensional support can be inflated to
have support in the full weight space.

However, the wrapper allows in general to pass function handles that preprocess
inputs and/or postprocess hypernetwork outputs.
"""
import torch.nn as nn

from hypnettorch.hnets.hnet_interface import HyperNetInterface

class HPerturbWrapper(nn.Module, HyperNetInterface):
    r"""Hypernetwork wrapper that perturbs the wrapped network's output.

    This wrapper is meant as a helper for hypernetworks that represent
    implicit distributions, i.e., distributions that transform a simple base
    distribution :math:`p_Z(Z)` into a complex target distribution. The
    density of such a distribution typically lives on a low-dimensional
    manifold of the weight space, which is often undesirable (e.g., not
    amenable to optimization with standard divergence measures). Adding a
    full-support noise perturbation to the hypernetwork output overcomes
    this issue.

    By default, the wrapper realizes the sampling procedure

    .. math::

        w \sim \tilde{q}_{\theta}(W) \Leftrightarrow
        w = h_{\theta}(z_{:n}) + \sigma_{\text{noise}} \, z \quad \text{,}
        \quad z \sim p_Z(Z)

    where :math:`z_{:n}` denotes the first :math:`n` entries of :math:`z`,
    such that :math:`\dim(\mathcal{W}) = \dim(\mathcal{Z})`. Hence, by
    default, the unconditional input of this wrapper has size
    ``hnet.num_outputs`` (if ``input_handler`` is not provided) and the
    output has the same size.

    More generally, arbitrary pre- and post-processing can be injected via
    the function handles ``input_handler`` and ``output_handler``.

    Args:
        hnet (hnets.hnet_interface.HyperNetInterface): The hypernetwork
            around which this wrapper should be wrapped.
        hnet_uncond_in_size (int): The number :math:`n` of leading entries
            of the unconditional input that are forwarded to ``hnet``.
            Ignored if ``input_handler`` is provided.
        sigma_noise (float): The perturbation strength
            :math:`\sigma_\text{noise}`. Ignored if ``output_handler`` is
            provided.
        input_handler (func, optional): A function handle to preprocess the
            inputs of method :meth:`forward` before they are passed to
            :attr:`internal_hnet`. Expected signature:

            .. code-block:: python

                uncond_input_int, cond_input_int, cond_id_int = \
                    input_handler(uncond_input=None, cond_input=None,
                                  cond_id=None)
        output_handler (func, optional): A function handle to postprocess
            the output of :attr:`internal_hnet`. Expected signature:

            .. code-block:: python

                hnet_out = output_handler(hnet_out_int, uncond_input=None,
                                          cond_input=None, cond_id=None)

            ``hnet_out_int`` is the output of :attr:`internal_hnet` in
            format ``ret_format='flattened'``, and the handler is expected
            to return this format as well.
        verbose (bool): Whether network information should be printed during
            network creation.
    """
    def __init__(self, hnet, hnet_uncond_in_size=None, sigma_noise=0.02,
                 input_handler=None, output_handler=None, verbose=True):
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        HyperNetInterface.__init__(self)

        assert isinstance(hnet, HyperNetInterface)

        self._hnet = hnet
        self._hnet_uncond_in_size = hnet_uncond_in_size
        self._sigma_noise = sigma_noise
        self._input_handler = input_handler
        self._output_handler = output_handler

        # Without a custom input handler we must know how many leading input
        # entries belong to the wrapped hypernetwork.
        if input_handler is None and hnet_uncond_in_size is None:
            raise ValueError('Either "input_handler" or "hnet_uncond_in_size"' +
                             ' has to be specified.')

        ### Setup attributes required by interface ###
        # Most of these attributes are taken over from `self._hnet`.
        self._target_shapes = hnet.target_shapes
        self._num_known_conds = self._hnet.num_known_conds
        self._unconditional_param_shapes_ref = \
            list(self._hnet.unconditional_param_shapes_ref)

        if self._hnet.internal_params is not None:
            self._internal_params = \
                nn.ParameterList(self._hnet.internal_params)
        self._param_shapes = list(self._hnet.param_shapes)
        self._param_shapes_meta = list(self._hnet.param_shapes_meta)

        if self._hnet.hyper_shapes_learned is not None:
            self._hyper_shapes_learned = \
                list(self._hnet.hyper_shapes_learned)
            self._hyper_shapes_learned_ref = \
                list(self._hnet.hyper_shapes_learned_ref)
        if self._hnet.hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = \
                list(self._hnet.hyper_shapes_distilled)

        self._has_bias = self._hnet.has_bias
        # A noise perturbed output can't be considered an FC output anymore.
        self._has_fc_out = False
        self._mask_fc_out = self._hnet.mask_fc_out
        # Guess that's the safest answer.
        self._has_linear_out = False

        self._layer_weight_tensors = \
            nn.ParameterList(self._hnet.layer_weight_tensors)
        self._layer_bias_vectors = \
            nn.ParameterList(self._hnet.layer_bias_vectors)
        if self._hnet.batchnorm_layers is not None:
            self._batchnorm_layers = \
                nn.ModuleList(self._hnet.batchnorm_layers)
        if self._hnet.context_mod_layers is not None:
            self._context_mod_layers = \
                nn.ModuleList(self._hnet.context_mod_layers)

        ### Finalize construction ###
        self._is_properly_setup()

        if verbose:
            print('Wrapped a perturbation interface around a hypernetwork.')

    @property
    def internal_hnet(self):
        """The underlying hypernetwork that was passed via constructor
        argument ``hnet``.

        :type: hnets.hnet_interface.HyperNetInterface
        """
        return self._hnet

    def forward(self, uncond_input=None, cond_input=None, cond_id=None,
                weights=None, distilled_params=None, condition=None,
                ret_format='squeezed'):
        """Compute the weights of a target network.

        Args:
            (....): See docstring of method
                :meth:`hnets.hnet_interface.HyperNetInterface.forward`.

        Returns:
            (list or torch.Tensor): See docstring of method
            :meth:`hnets.hnet_interface.HyperNetInterface.forward`.
        """
        # Determine the inputs that the wrapped hypernetwork will see.
        if self._input_handler is not None:
            u_in, c_in, c_id = self._input_handler(uncond_input=uncond_input,
                cond_input=cond_input, cond_id=cond_id)
        else:
            assert uncond_input is not None
            # Default behavior: feed only the first `hnet_uncond_in_size`
            # entries of the unconditional input to the internal hnet.
            u_in = uncond_input[:, :self._hnet_uncond_in_size]
            c_in = cond_input
            c_id = cond_id

        flat_out = self._hnet.forward(uncond_input=u_in, cond_input=c_in,
            cond_id=c_id, weights=weights, distilled_params=distilled_params,
            condition=condition, ret_format='flattened')

        # Postprocess the flattened output (default: additive perturbation
        # scaled by `sigma_noise`).
        if self._output_handler is not None:
            hnet_out = self._output_handler(flat_out,
                uncond_input=uncond_input, cond_input=cond_input,
                cond_id=cond_id)
        else:
            assert flat_out.shape == uncond_input.shape
            hnet_out = flat_out + self._sigma_noise * uncond_input

        ### Split output into target shapes ###
        return self._flat_to_ret_format(hnet_out, ret_format)

    def distillation_targets(self):
        """Targets to be distilled after training.

        See docstring of abstract super method
        :meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.

        Returns:
            The ``distillation_targets`` of the internal hypernet
            :attr:`internal_hnet`.
        """
        return self.internal_hnet.distillation_targets()
# This module is a library; there is no standalone script behavior.
if __name__ == '__main__': pass