Source code for hypnettorch.utils.hmc

#!/usr/bin/env python3
# Copyright 2021 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title          :utils/hmc.py
# @author         :ch
# @contact        :henningc@ethz.ch
# @created        :03/09/2021
# @version        :1.0
# @python_version :3.8.5
r"""
Hamiltonian-Monte-Carlo
-----------------------

The module :mod:`utils.hmc` implements the Hamiltonian-Monte-Carlo (HMC)
algorithm as described in

    Neal, `MCMC using Hamiltonian dynamics <https://arxiv.org/abs/1206.1901>`__,
    2012.

The pseudocode of the algorithm is described in Figure 2 of the paper. The
algorithm uses the Leapfrog algorithm to simulate the Hamiltonian dynamics in
discrete time. Therefore, two crucial hyperparameters are required: the stepsize
:math:`\epsilon` and the number of steps :math:`L`. Both hyperparameters have to
be chosen with care and can drastically influence the behavior of HMC. If the
stepsize :math:`\epsilon` is too small, we don't explore the state space
efficiently and waste computation. If it is too big, the numerical error from
the discretization may become too large and the acceptance rate rather low. In
addition, we want to choose :math:`L` large enough to obtain good exploration,
but if we set it too large we might loop back to the starting position.

The No-U-Turn-Sampler (NUTS) has been proposed to set :math:`L` automatically,
such that only the stepsize :math:`\epsilon` has to be chosen.

    Hoffman et al.,
    `The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian
    Monte Carlo <https://arxiv.org/abs/1111.4246>`__, 2011.

This module provides implementations for both variants, basic :class:`HMC` and
:class:`NUTS`. Multiple parallel chains can be simulated via class
:class:`MultiChainHMC`. For Bayesian Neural Networks, the helper function
:func:`nn_pot_energy` can be used to define the potential energy.

**Notation**

We largely follow the notation from
`Neal et al. <https://arxiv.org/abs/1206.1901>`__. The variables of interest,
e.g., model parameters, are encoded by the position vector :math:`q`. In
addition, HMC requires a momentum :math:`p`. The Hamiltonian :math:`H(q, p)`
consists of two terms, the potential energy :math:`U(q)` and the kinetic energy
:math:`K(p) = p^T M^{-1} p / 2` with :math:`M` being a symmetric, p.d. "mass"
matrix.

The Hamiltonian dynamics can thus be summarized as

.. math::

    \frac{dq_i}{dt} &= \frac{\partial H}{\partial p_i} = [M^{-1} p]_i \\
    \frac{dp_i}{dt} &= -\frac{\partial H}{\partial q_i} = \
        - \frac{\partial U}{\partial q_i}

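As a concrete numeric sketch (illustrative only, not part of this module): for
a diagonal mass matrix, both the kinetic energy and the velocity
:math:`dq/dt = M^{-1} p` reduce to elementwise operations:

```python
# Illustrative sketch: kinetic energy and velocity for a diagonal mass
# matrix M = diag(m_1, ..., m_d). Here `inv_m` holds the entries 1/m_i,
# so M^{-1} p is an elementwise product.

def kinetic_energy(p, inv_m):
    """K(p) = p^T M^{-1} p / 2 for a diagonal mass matrix."""
    return 0.5 * sum(pi * im * pi for pi, im in zip(p, inv_m))

def velocity(p, inv_m):
    """dq/dt = M^{-1} p, computed elementwise."""
    return [im * pi for pi, im in zip(p, inv_m)]

p = [1.0, 2.0]
inv_m = [1.0, 0.5]             # masses m = (1, 2)
K = kinetic_energy(p, inv_m)   # 0.5 * (1*1*1 + 2*0.5*2) = 1.5
v = velocity(p, inv_m)         # [1.0, 1.0]
```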
The Leapfrog algorithm is a way to discretize the differential equation above
in a way that is reversible and volume-preserving. The algorithm has two
hyperparameters: the stepsize :math:`\epsilon` and the number of steps
:math:`L`. Below, we sketch the algorithm to update momentum and position from
time :math:`t` to time :math:`t + L\epsilon`.

.. math::

    p_i(t + \frac{\epsilon}{2}) &= p_i(t) - \frac{\epsilon}{2} \
        \frac{\partial U}{\partial q_i} \big( q(t) \big) \\
    q_i(t + l\epsilon) &= q_i(t + (l-1)\epsilon) + \epsilon \
        \frac{p_i(t + (l-1)\epsilon + \epsilon/2)}{m_i} \quad \forall l = 1..L\\
    p_i(t + l\epsilon + \frac{\epsilon}{2}) &= \
        p_i(t + (l-1)\epsilon + \frac{\epsilon}{2}) - \epsilon \
        \frac{\partial U}{\partial q_i} \big( q(t+l\epsilon) \big) \
        \quad \forall l = 1..L-1\\
    p_i(t + L\epsilon) &= p_i(t + (L-1)\epsilon + \frac{\epsilon}{2}) -\
        \frac{\epsilon}{2} \frac{\partial U}{\partial q_i} \
        \big( q(t+L\epsilon) \big)

We assume a diagonal mass matrix in the position update above.
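To make the update scheme above concrete, here is a minimal self-contained
sketch (plain Python with unit mass, not part of this module) that runs one
leapfrog trajectory for the potential :math:`U(q) = q^2/2` and checks that the
Hamiltonian is approximately conserved:

```python
def leapfrog_1d(q, p, eps, L, grad_U):
    """One leapfrog trajectory with L steps of size eps (unit mass)."""
    p -= 0.5 * eps * grad_U(q)       # initial half-step for the momentum
    for l in range(L):
        q += eps * p                 # full step for the position
        if l < L - 1:
            p -= eps * grad_U(q)     # full step for the momentum
    p -= 0.5 * eps * grad_U(q)       # final half-step for the momentum
    return q, p

# Standard normal target: U(q) = q^2 / 2, hence dU/dq = q.
q0, p0 = 1.0, 0.5
q1, p1 = leapfrog_1d(q0, p0, eps=0.1, L=10, grad_U=lambda q: q)

# The Hamiltonian H = U(q) + p^2/2 is approximately conserved, which is
# what keeps the HMC acceptance rate high for reasonable stepsizes.
H0 = 0.5 * q0**2 + 0.5 * p0**2
H1 = 0.5 * q1**2 + 0.5 * p1**2
```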

.. autosummary::

    hypnettorch.utils.hmc.HMC
    hypnettorch.utils.hmc.MCMC
    hypnettorch.utils.hmc.MultiChainHMC
    hypnettorch.utils.hmc.NUTS
    hypnettorch.utils.hmc.leapfrog
    hypnettorch.utils.hmc.log_prob_standard_normal_prior
    hypnettorch.utils.hmc.nn_pot_energy
"""
import logging
import numpy as np
from os import path
from queue import Queue
import sys
from tensorboardX import SummaryWriter
from threading import Thread
import torch
from torch.distributions import Normal, MultivariateNormal
import torch.nn.functional as F
from warnings import warn

from hypnettorch.mnets.mnet_interface import MainNetInterface

def _grad_pot_energy(pot_energy, position):
    r"""Compute the partial derivatives of the potential energy for the current
    position :math:`q(t)`.

    Args:
        (....): See docstring of function :func:`leapfrog`.

    Returns:
        (torch.Tensor): The vector of partial derivatives
        :math:`\frac{\partial U}{\partial q} \big( q(t) \big)`.
    """
    pe_val = pot_energy(position)
    pot_grad, = torch.autograd.grad(pe_val, position, only_inputs=True)
    return pot_grad

def _grad_kin_energy(momentum, inv_mass):
    r"""Compute the partial derivatives of the kinetic energy for the current
    momentum :math:`p(t)`.

    This function assumes a kinetic energy of the form
    :math:`K(p) = p^T M^{-1} p / 2`.

    Args:
        (....): See docstring of function :func:`leapfrog`.

    Returns:
        (torch.Tensor): The vector of partial derivatives
        :math:`\frac{\partial K}{\partial p} \big( p(t) \big)`.
    """
    if isinstance(inv_mass, torch.Tensor):
        if inv_mass.numel() == momentum.numel(): # Assuming diagonal mass
            grad_ke = inv_mass * momentum # Element-wise product.
        else: # Assuming full matrix
            grad_ke = torch.matmul(inv_mass, momentum)
    else: # Assuming a single scalar
        grad_ke = inv_mass * momentum

    return grad_ke

def _kin_energy(momentum, inv_mass):
    """Compute the kinetic energy for the current momentum :math:`p(t)`.

    This function assumes a kinetic energy according to
    :math:`K(p) = p^T M^{-1} p / 2`.

    Args:
        (....): See docstring of function :func:`leapfrog`.

    Returns:
        (torch.Tensor): The scalar energy value.
    """
    if isinstance(inv_mass, torch.Tensor):
        if inv_mass.numel() == momentum.numel(): # Assuming diagonal mass
            ke = .5 * torch.dot(momentum, inv_mass * momentum)
        else: # Assuming full matrix
            ke = .5 *  torch.dot(momentum, torch.matmul(inv_mass, momentum))
    else: # Assuming a single scalar
        ke = .5 * inv_mass * torch.dot(momentum, momentum)

    return ke

def leapfrog(position, momentum, stepsize, num_steps, inv_mass, pot_energy):
    r"""Implementation of the leapfrog algorithm.

    The leapfrog algorithm updates position :math:`q` and momentum :math:`p`
    variables by simulating the Hamiltonian dynamics in discrete time for a
    time window of size :math:`L\epsilon`, where :math:`L` is the number of
    leapfrog steps ``num_steps`` and :math:`\epsilon` is the ``stepsize``.

    In general, one can call this method :math:`L` times while setting
    ``num_steps=1`` in order to obtain the complete trajectory. However, if
    not necessary, we recommend setting ``num_steps=L`` to save the
    unnecessary computation of intermediate momentum variables.

    Args:
        position (torch.Tensor): The position variable :math:`q`.
        momentum (torch.Tensor): The momentum variable :math:`p`.
        stepsize (float): The leapfrog stepsize :math:`\epsilon`.
        num_steps (int): The number of leapfrog steps :math:`L`.
        inv_mass (float or torch.Tensor): The inverse mass matrix
            :math:`M^{-1}`. Can also be provided as a vector, in case of a
            diagonal mass matrix, or as a scalar.
        pot_energy (func): A function handle that computes the potential
            energy :math:`U\big( q(t) \big)`, receiving as only input the
            current position variable.

    Note:
        The function handle ``pot_energy`` has to be amenable to
        :mod:`torch.autograd`, as the momentum update requires the partial
        derivatives of the potential energy.

    Returns:
        (tuple): Tuple containing:

        - **position** (torch.Tensor): The updated position variable.
        - **momentum** (torch.Tensor): The updated momentum variable.
    """
    # p(t + \epsilon/2)
    momentum = momentum - .5 * stepsize * _grad_pot_energy(pot_energy,
                                                           position)

    for l in range(num_steps):
        # Compute the gradient of the kinetic energy.
        grad_ke = _grad_kin_energy(momentum, inv_mass)
        # q(t + (l+1)\epsilon)
        position = position + stepsize * grad_ke
        if l < num_steps - 1:
            # p(t + (l+1)\epsilon + \epsilon/2)
            momentum = momentum - \
                stepsize * _grad_pot_energy(pot_energy, position)

    # p(t + L\epsilon)
    momentum = momentum - .5 * stepsize * _grad_pot_energy(pot_energy,
                                                           position)

    return position, momentum
def log_prob_standard_normal_prior(position, mean=0., std=1.):
    r"""Log-probability density of a standard normal prior.

    This function can be used to compute :math:`\log p(q)` for
    :math:`p(q) = \mathcal{N}(q; \bm{\mu}, I \bm{\sigma}^2)`, where :math:`I`
    denotes the identity matrix.

    This function can be passed to :func:`nn_pot_energy` as argument
    ``prior_log_prob_func`` using, for instance:

    .. code-block:: python

        lp_func = lambda q: log_prob_standard_normal_prior(q, mean=0.,
                                                           std=.02)

    Args:
        position (torch.Tensor): The position variable :math:`q`.
        mean (float or torch.Tensor): The mean of the diagonal Gaussian
            prior.
        std (float or torch.Tensor): The standard deviation of the diagonal
            Gaussian prior.
    """
    prior_dist = Normal(mean, std)
    return prior_dist.log_prob(position).sum()
def nn_pot_energy(net, inputs, targets, prior_log_prob_func, tau_pred=1.,
                  nll_type='regression'):
    r"""The potential energy for Bayesian inference with HMC using neural
    networks.

    When obtaining samples from the posterior parameter distribution of a
    neural network via HMC, a potential energy function has to be specified
    that allows evaluating the negative log-posterior up to a constant. We
    consider a neural network with parameters :math:`W` which encodes a
    likelihood function :math:`p(y \mid W; x)` for an input :math:`x`. In
    addition, a prior :math:`p(W)` needs to be specified. Given a dataset
    :math:`\mathcal{D}` consisting of ``inputs`` :math:`x_n` and ``targets``
    :math:`y_n`, we can specify the potential energy as (note, here
    :math:`q = W`)

    .. math::

        U(W) &= - \log p(\mathcal{D} \mid W) - \log p(W) \\
            &= - \sum_n \log p(y_n \mid W; x_n) - \log p(W)

    where the first term corresponds to the negative log-likelihood (NLL).
    The precise way of computing the NLL depends on which kind of likelihood
    interpretation is forced onto the network (cf. argument ``nll_type``).

    Args:
        net (mnets.mnet_interface.MainNetInterface): The considered neural
            network, whose parameters are :math:`W`.
        inputs (torch.Tensor): A tensor containing all the input sample
            points :math:`x_n` in :math:`\mathcal{D}`.
        targets (torch.Tensor): A tensor containing all the output sample
            points :math:`y_n` in :math:`\mathcal{D}`.
        prior_log_prob_func (func): Function handle that allows computing the
            log-probability density of the prior for a given position
            variate.
        tau_pred (float): Only applies to ``nll_type='regression'``. The
            inverse variance of the assumed Gaussian likelihood.
        nll_type (str): The type of likelihood interpretation enforced on the
            network. The following options are supported:

            - ``'regression'``: The network outputs the mean of a 1D normal
              distribution with fixed variance

              .. math::

                  \text{NLL} = \frac{1}{2 \sigma_\text{ll}^2} \
                      \sum_{(x, y) \in \mathcal{D}} \
                      \big( f_\text{M}(x, W) - y \big)^2

              where :math:`f_\text{M}(x, W)` is the network output and
              :math:`\frac{1}{\sigma_\text{ll}^2}` corresponds to
              ``tau_pred``.
            - ``'classification'``: Multi-class classification with a softmax
              likelihood. Note, we assume the network has linear (logit)
              outputs

              .. math::

                  \text{NLL} = \sum_{(\mathbf{x}, y) \in \mathcal{D}} \
                      \bigg( \underbrace{- \sum_{c=0}^{C-1} [c = y] \
                      \log \Big( \text{softmax} \big( \
                      f_\text{M}(\mathbf{x}, W) \big)_c \Big) \
                      }_{\text{cross-entropy loss with 1-hot targets}} \
                      \bigg)

              where :math:`C` is the number of classes and :math:`y` are
              integer labels. We assume that the neural network
              :math:`f_\text{M}(\mathbf{x}, W)` outputs logits.

              .. note::

                  We assume ``targets`` contains integer labels and **not**
                  1-hot encodings for ``'classification'``!

    Returns:
        (func): A function handle as required by constructor argument
        ``pot_energy_func`` of class :class:`HMC`.
    """
    assert nll_type in ['regression', 'classification']
    if nll_type != 'regression':
        assert tau_pred == 1.

    def pot_energy_func(position):
        weights = MainNetInterface.flatten_params(position,
            param_shapes=net.param_shapes, unflatten=True)
        preds = net(inputs, weights=weights)

        if nll_type == 'regression':
            nll = 0.5 * tau_pred * F.mse_loss(preds, targets,
                                              reduction='sum')
        elif nll_type == 'classification':
            # Note, we assume that `targets` are integer labels!
            nll = F.cross_entropy(preds, targets, reduction='sum')
        else:
            raise NotImplementedError()

        return nll - prior_log_prob_func(position)

    return pot_energy_func
class HMC:
    r"""This class represents the basic HMC algorithm.

    The algorithm is implemented as outlined in Fig. 2 of
    `Neal et al. <https://arxiv.org/abs/1206.1901>`__.

    The potential energy should be the negative log-probability density of
    the target distribution to sample from (up to a constant):
    :math:`U(q) = - \log p(q) + \text{const.}`

    Args:
        initial_position (torch.Tensor): The initial position :math:`q(0)`.

            Note:
                The position variable should be provided as a vector. The
                weights of a neural network can be flattened via
                :meth:`mnets.mnet_interface.MainNetInterface.flatten_params`.
        pot_energy_func (func): A function handle computing the potential
            energy :math:`U(q)` upon receiving a position :math:`q`. To
            sample the weights of a neural network, the helper function
            :func:`nn_pot_energy` can be used.

            To sample via HMC from a target distribution implemented via
            :class:`torch.distributions.distribution.Distribution`, one can
            define a function handle as in the following example.

            Example:
                .. code-block:: python

                    d = MultivariateNormal(torch.zeros(4), torch.eye(4))
                    pot_energy_func = lambda q : - d.log_prob(q)
        stepsize (float): The stepsize :math:`\epsilon` of the
            :func:`leapfrog` algorithm.
        num_steps (int): The number of steps :math:`L` in the
            :func:`leapfrog` algorithm.
        inv_mass (float or torch.Tensor): The inverse "mass" matrix as
            required for the computation of the kinetic energy :math:`K(p)`.
            See argument ``inv_mass`` of function :func:`leapfrog` for
            details.
        logger (logging.Logger, optional): If provided, the progress will be
            logged.
        log_interval (int): After how many states the status should be
            logged.
        writer (tensorboardX.SummaryWriter, optional): A tensorboard writer.
            If given, useful simulation data will be logged, like the
            development of the Hamiltonian.
        writer_tag (str): Will be added to the tensorboard tags.
    """
    def __init__(self, initial_position, pot_energy_func, stepsize=.02,
                 num_steps=1, inv_mass=1., logger=None, log_interval=100,
                 writer=None, writer_tag=''):
        self._position = initial_position
        if not self._position.requires_grad:
            self._position.requires_grad = True
        self._stepsize = stepsize
        self._num_steps = num_steps
        self._pot_energy_func = pot_energy_func
        self._inv_mass = inv_mass
        self._logger = logger
        self._log_interval = log_interval
        self._writer = writer
        self._writer_tag = writer_tag

        self._positions = [initial_position]
        self._num_states = 0
        self._accumulated_accept = 0

        # Define the distribution from which to sample the momentum.
        if isinstance(inv_mass, torch.Tensor) and len(inv_mass.shape) == 2:
            self._momentum_dist = MultivariateNormal( \
                torch.zeros_like(initial_position),
                precision_matrix=inv_mass)
        else:
            mass = 1. / inv_mass
            # Note, we need to pass the standard deviation and not the
            # variance.
            self._momentum_dist = Normal(torch.zeros_like(initial_position),
                                         mass**0.5)

    @property
    def stepsize(self):
        r"""The stepsize :math:`\epsilon` of the :func:`leapfrog` algorithm.

        You may adapt the stepsize at any point.

        :type: float
        """
        return self._stepsize

    @stepsize.setter
    def stepsize(self, value):
        self._stepsize = value

    @property
    def num_steps(self):
        r"""The number of steps :math:`L` in the :func:`leapfrog` algorithm.

        You may adapt the number of steps at any point.

        :type: int
        """
        return self._num_steps

    @num_steps.setter
    def num_steps(self, value):
        self._num_steps = value

    @property
    def current_position(self):
        """The latest position :math:`q(t)` in the chain simulated so far.

        :type: torch.Tensor
        """
        return self._position

    @property
    def num_states(self):
        """The number of states in the chain visited so far.

        The counter will be increased by method :meth:`simulate_chain`.

        :type: int
        """
        return self._num_states

    @property
    def position_trajectory(self):
        """A list containing all position variables (Markov states) visited
        so far.

        New positions will be added by the method :meth:`simulate_chain`. To
        decrease the memory footprint of objects in this class, the
        trajectory can be cleared via method
        :meth:`clear_position_trajectory`.

        :type: list
        """
        return self._positions

    @property
    def acceptance_probability(self):
        """The fraction of states that have been accepted.

        :type: float
        """
        if self.num_states == 0:
            return 1.0
        return self._accumulated_accept / self._num_states
    def clear_position_trajectory(self, n=None):
        """Reset attribute :attr:`position_trajectory`.

        This method will not affect the counter :attr:`num_states`.

        Args:
            n (int, optional): If provided, only the first ``n`` elements of
                :attr:`position_trajectory` are discarded (e.g., the burn-in
                samples).
        """
        if n is not None:
            self._positions = self._positions[n:]
        else:
            self._positions = []
    def simulate_chain(self, n):
        """Simulate the next ``n`` states of the chain.

        The new states will be appended to attribute
        :attr:`position_trajectory`.

        Args:
            n (int): Number of HMC steps to be executed.
        """
        logger = self._logger
        writer = self._writer

        for _ in range(n):
            curr_q = self.current_position
            # Resample momentum.
            curr_p = self._momentum_dist.sample()

            # Simulate Hamiltonian dynamics.
            q = curr_q.detach().clone()
            p = curr_p.detach().clone()
            if not q.requires_grad:
                q.requires_grad = True
            q, p = leapfrog(q, p, self.stepsize, self.num_steps,
                            self._inv_mass, self._pot_energy_func)
            # Negation of the momentum is not required in simulation.

            # Evaluate the Hamiltonian at the beginning and at the end of
            # the trajectory.
            k_p_start = _kin_energy(curr_p, self._inv_mass)
            u_q_start = self._pot_energy_func(curr_q)
            k_p_proposal = _kin_energy(p, self._inv_mass)
            u_q_proposal = self._pot_energy_func(q)

            # Metropolis update.
            if torch.rand(1).to(p.device) < \
                    torch.exp(u_q_start - u_q_proposal + \
                              k_p_start - k_p_proposal):
                accept = True
                self._accumulated_accept += 1
                self._positions.append(q)
            else: # Reject.
                accept = False
                self._positions.append(curr_q.clone())

            self._position = self._positions[-1]
            self._num_states += 1

            # Log progress.
            if accept:
                kinetic = k_p_proposal.detach().cpu().numpy()
                potential = u_q_proposal.detach().cpu().numpy()
            else:
                kinetic = k_p_start.detach().cpu().numpy()
                potential = u_q_start.detach().cpu().numpy()
            hamiltonian = kinetic + potential

            if logger is not None and \
                    (self.num_states-1) % self._log_interval == 0:
                logger.debug('HMC state %d: Current Hamiltonian: %f - ' \
                             % (self.num_states, hamiltonian) + \
                             'Acceptance probability: %.2f%%.' \
                             % (self.acceptance_probability * 100))

            if writer is not None:
                tag = self._writer_tag
                writer.add_scalar('%shmc/kinetic' % tag, kinetic,
                                  global_step=self.num_states,
                                  display_name='Kinetic Energy')
                writer.add_scalar('%shmc/potential' % tag, potential,
                                  global_step=self.num_states,
                                  display_name='Potential Energy')
                writer.add_scalar('%shmc/hamiltonian' % tag, hamiltonian,
                                  global_step=self.num_states,
                                  display_name='Hamiltonian')
                writer.add_scalar('%shmc/accept' % tag,
                                  self.acceptance_probability,
                                  global_step=self.num_states,
                                  display_name='Acceptance Probability')
class NUTS(HMC):
    r"""HMC with No-U-Turn Sampler (NUTS).

    In this class, we implement the efficient version of the NUTS algorithm
    (see algorithm 3 in
    `Hoffman et al. <https://arxiv.org/abs/1111.4246>`__).

    NUTS eliminates the need to choose the number of leapfrog steps
    :math:`L`. While the algorithm is computationally more expensive than
    basic HMC, the reduced hyperparameter effort has been shown to reduce
    the overall computational cost (and it requires less human
    intervention).

    As explained in the paper, a good heuristic to set :math:`L` is to
    choose the highest number (for a given :math:`\epsilon`) before the
    trajectory loops back to the initial position :math:`q_0`, e.g., when
    the following quantity becomes negative

    .. math::

        \frac{d}{dt} \frac{1}{2} \lVert q - q_0 \rVert_2^2 = \
            \langle q - q_0, p \rangle

    Note, this equation assumes the `mass matrix` is the identity:
    :math:`M=I`. However, this approach is in general not time-reversible,
    therefore NUTS proposes a recursive algorithm that allows backtracing.
    NUTS randomly adds subtrees to a balanced binary tree and stops when any
    of those subtrees starts making a "U-turn" (either forward or backward
    in time). This tree construction is fully symmetric and therefore
    reversible.

    Note:
        The NUTS paper also proposes a heuristic approach to adapt the
        stepsize :math:`\epsilon` together with :math:`L` (e.g., see
        algorithm 6 in
        `Hoffman et al. <https://arxiv.org/abs/1111.4246>`__). Such stepsize
        adaptation is currently not implemented by this class!

    Args:
        (....): See docstring of class :class:`HMC`.
        delta_max (float): The nonnegative criterion
            :math:`\Delta_\text{max}` from Eq. 8 of
            `Hoffman et al. <https://arxiv.org/abs/1111.4246>`__, that
            should ensure that we stop NUTS if the energy becomes too big.
    """
    def __init__(self, initial_position, pot_energy_func, stepsize=.02,
                 delta_max=1000., inv_mass=1., logger=None, log_interval=100,
                 writer=None, writer_tag=''):
        HMC.__init__(self, initial_position, pot_energy_func,
                     stepsize=stepsize, num_steps=None, inv_mass=inv_mass,
                     logger=logger, log_interval=log_interval, writer=writer,
                     writer_tag=writer_tag)

        self._delta_max = delta_max

    # Overwrite base attribute.
    @property
    def num_steps(self):
        """The attribute :attr:`HMC.num_steps` does not exist for class
        :class:`NUTS`! Accessing this attribute will cause an error.
        """
        raise RuntimeError('NUTS has no attribute "num_steps".')

    @num_steps.setter
    def num_steps(self, value):
        raise RuntimeError('NUTS has no attribute "num_steps".')

    def _u_turn(self, q1, q2, p):
        r"""Detect whether a U-turn has been made.

        This method simply computes

        .. math::

            \Big[ \langle q_2 - q_1, M^{-1} p \rangle \geq 0 \Big]

        where :math:`[\cdot]` denotes the Iverson bracket, i.e., the
        returned indicator is zero if a U-turn has occurred.

        Returns:
            (int)
        """
        # Note, the product of inverse mass matrix and momentum is simply
        # the gradient of the kinetic energy.
        angle = torch.dot(q2 - q1, _grad_kin_energy(p, self._inv_mass))
        return int(angle >= 0)

    def _build_tree(self, q, p, u, v, j):
        """Build the NUTS tree recursively.

        See function "BuildTree" in algorithm 3 of the NUTS paper.
        """
        if j == 0:
            q1, p1 = leapfrog(q, p, v * self.stepsize, 1, self._inv_mass,
                              self._pot_energy_func)
            # The log-probability is, up to additive constants, the negative
            # total energy (or Hamiltonian).
            k_p1 = _kin_energy(p1, self._inv_mass)
            u_q1 = self._pot_energy_func(q1)
            log_prob = -u_q1 - k_p1

            log_u = torch.log(u)
            n1 = int(log_u <= log_prob)
            s1 = int(log_prob > log_u - self._delta_max)

            return q1, p1, q1, p1, q1, n1, s1
        else:
            q_m, p_m, q_p, p_p, q1, n1, s1 = self._build_tree(q, p, u, v,
                                                              j-1)
            if s1 == 1:
                if v < 0:
                    q_m, p_m, _, _, q2, n2, s2 = self._build_tree(q_m, p_m,
                        u, v, j-1)
                else:
                    _, _, q_p, p_p, q2, n2, s2 = self._build_tree(q_p, p_p,
                        u, v, j-1)

                n = n1 + n2
                if n > 0: # This step is a bit ambiguous in the pseudo-code!
                    if torch.rand(1) <= n2 / n:
                        q1 = q2
                s1 = s2 * self._u_turn(q_m, q_p, p_m) * \
                    self._u_turn(q_m, q_p, p_p)
                n1 = n1 + n2

            return q_m, p_m, q_p, p_p, q1, n1, s1
    def simulate_chain(self, n):
        """Simulate the next ``n`` states of the chain.

        The new states will be appended to attribute
        :attr:`position_trajectory`.

        Args:
            n (int): Number of HMC steps to be executed.
        """
        logger = self._logger
        writer = self._writer

        device = self.current_position.device

        for _ in range(n):
            curr_q = self.current_position
            # Resample momentum.
            curr_p = self._momentum_dist.sample()

            # Sample slice variable.
            curr_K = _kin_energy(curr_p, self._inv_mass)
            curr_U = self._pot_energy_func(curr_q)
            curr_prob = torch.exp(-curr_U - curr_K) # Unnormalized prob.
            u = torch.rand(1).to(device) * curr_prob

            # Initialize some variables.
            q_minus = curr_q
            p_minus = curr_p
            q_plus = curr_q
            p_plus = curr_p
            # The new state.
            q_new = None

            j = 0
            n_valid = 1 # Number of valid states; `n` in the paper.
            s = 1

            while s == 1:
                # Choose a random direction.
                v = -1 if torch.rand(1) < .5 else 1

                if v == -1:
                    q_minus, p_minus, _, _, q_proposed, n_new, s_new = \
                        self._build_tree(q_minus, p_minus, u, v, j)
                else:
                    _, _, q_plus, p_plus, q_proposed, n_new, s_new = \
                        self._build_tree(q_plus, p_plus, u, v, j)

                if s_new == 1:
                    if torch.rand(1) < min(1, n_new / n_valid):
                        q_new = q_proposed
                n_valid += n_new
                s = s_new * self._u_turn(q_minus, q_plus, p_minus) * \
                    self._u_turn(q_minus, q_plus, p_plus)
                j += 1

            if q_new is not None:
                self._accumulated_accept += 1
                self._positions.append(q_new.detach().clone())
                new_U = self._pot_energy_func(q_new)
            else:
                self._positions.append(curr_q.detach().clone())
                new_U = curr_U
            if not self._positions[-1].requires_grad:
                self._positions[-1].requires_grad = True
            new_U = new_U.detach().cpu().numpy()

            self._position = self._positions[-1]
            self._num_states += 1

            # Log progress.
            if logger is not None and \
                    (self.num_states-1) % self._log_interval == 0:
                logger.debug('NUTS state %d: Current Potential Energy: ' \
                             '%f - ' % (self.num_states, new_U) + \
                             'Acceptance probability: %.2f%%.' \
                             % (self.acceptance_probability * 100))

            if writer is not None:
                tag = self._writer_tag
                writer.add_scalar('%snuts/potential' % tag, new_U,
                                  global_step=self.num_states,
                                  display_name='Potential Energy')
                writer.add_scalar('%snuts/accept' % tag,
                                  self.acceptance_probability,
                                  global_step=self.num_states,
                                  display_name='Acceptance Probability')
                writer.add_scalar('%snuts/tree_depth' % tag, j,
                                  global_step=self.num_states,
                                  display_name='Tree Depth')
class MultiChainHMC():
    r"""Wrapper for running multiple HMC chains in parallel.

    Samples obtained via an MCMC sampler are highly auto-correlated for two
    reasons: (1) the proposal distribution is conditioned on the previous
    state and (2) because of rejection (consecutive states are identical).
    In addition, it is unclear when the chain has run long enough such that
    sufficient exploration has taken place and the sample (excluding the
    initial burn-in) can be considered an i.i.d. sample from the target
    distribution. For this reason, it is recommended to obtain an MCMC
    sample by running multiple chains in parallel, starting from varying
    initial positions :math:`q(0)`.

    This class provides a simple wrapper to instantiate multiple chains from
    :class:`HMC` (and its subclasses) and provides an interface to easily
    simulate those chains.

    Args:
        initial_positions (list or tuple): A list of initial positions. The
            length of this list will determine the number of chains to be
            instantiated. Each element is an initial position as described
            for argument ``initial_position`` of class :class:`HMC`.
        pot_energy_func (func): See docstring of class :class:`HMC`. One may
            also provide a list of functions. For instance, if the potential
            energy of a Bayesian neural network should be computed, there
            might be a runtime speedup if each function uses a separate
            model instance.
        chain_type (str): The type of HMC algorithm to be used. The
            following options are available:

            - ``'hmc'``: Each chain will be an instance of class
              :class:`HMC`.
            - ``'nuts'``: Each chain will be an instance of class
              :class:`NUTS`.
        **kwargs: Keyword arguments that will be passed to the constructor
            when instantiating each chain. The following particularities
            should be noted.

            - If a ``writer`` object is passed, then a chain-specific
              identifier is added to the corresponding ``writer_tag``,
              except if ``writer`` is a string. In this case, we assume
              ``writer`` corresponds to an output directory and we construct
              a separate object of class
              :class:`tensorboardX.SummaryWriter` per chain. In the latter
              case, the scalars logged across chains are all shown within
              the same tensorboard plot and are therefore easier to compare.
            - If a ``logger`` object is passed, then it will only be
              provided to the first chain. If a logger should be passed to
              multiple chain instances, then a list of objects from class
              :class:`logging.Logger` is required. If entries in this list
              are ``None``, then a simple console logger is generated for
              these entries that displays the chain's identity when logging
              a message.
    """
    def __init__(self, initial_positions, pot_energy_func, chain_type='hmc',
                 **kwargs):
        assert chain_type in ['hmc', 'nuts']
        self._num_chains = len(initial_positions)

        # Determine `writer` and `logger` per chain.
        writers = None
        if 'writer' in kwargs.keys():
            writers = kwargs['writer']
        writer_tags = None
        writer_tag = ''
        if 'writer_tag' in kwargs.keys():
            writer_tag = kwargs['writer_tag']
        loggers = None
        if 'logger' in kwargs.keys():
            loggers = kwargs['logger']

        self._close_writers = False
        if writers is not None:
            if isinstance(writers, SummaryWriter):
                # Same writer for all chains, but different tags.
                writers = [writers] * self.num_chains
                writer_tags = [writer_tag + 'chain_%d/' % ii for ii in \
                               range(self.num_chains)]
            else:
                assert isinstance(writers, str)
                summary_dir = writers
                writers = [SummaryWriter(logdir=path.join(summary_dir, \
                    'chain_%d' % ii)) for ii in range(self.num_chains)]
                # Since we created these writer objects, we should
                # explicitly close them when the object is destroyed.
                self._close_writers = True
        self._writers = writers

        if loggers is not None:
            if isinstance(loggers, logging.Logger):
                logger = loggers
                loggers = [None] * self.num_chains
                loggers[0] = logger
            else:
                for ii, logger in enumerate(loggers):
                    if logger is None:
                        loggers[ii] = \
                            MultiChainHMC._get_chain_specific_logger(ii)

        pot_energy_funcs = pot_energy_func
        if not isinstance(pot_energy_func, (list, tuple)):
            pot_energy_funcs = [pot_energy_func] * self.num_chains

        # Instantiate chains.
        self._chains = []
        for cidx in range(self.num_chains):
            if writers is not None:
                kwargs['writer'] = writers[cidx]
            if writer_tags is not None:
                kwargs['writer_tag'] = writer_tags[cidx]
            if loggers is not None:
                kwargs['logger'] = loggers[cidx]
            pe_func = pot_energy_funcs[cidx]

            if chain_type == 'hmc':
                chain = HMC(initial_positions[cidx], pe_func, **kwargs)
            else: # 'nuts'
                chain = NUTS(initial_positions[cidx], pe_func, **kwargs)
            self._chains.append(chain)

    @property
    def num_chains(self):
        """The number of chains managed by this instance.

        :type: int
        """
        return self._num_chains

    @property
    def chains(self):
        """The list of internally managed HMC objects.

        :type: list
        """
        return self._chains

    @property
    def avg_acceptance_probability(self):
        """The average fraction of states that have been accepted across all
        chains.

        :type: float
        """
        avg_ap = np.mean([c.acceptance_probability for c in self.chains])
        return float(avg_ap)
    def simulate_chains(self, num_states, num_chains=-1, num_parallel=1):
        """Simulate the chains to gather a certain number of new positions.

        This method simulates the internal chains to add ``num_states``
        positions to each considered chain.

        Args:
            num_states (int): Each considered chain will be simulated for
                this amount of HMC steps (see argument ``n`` of method
                :meth:`HMC.simulate_chain`).
            num_chains (int or list): The number of chains to be considered.
                If ``-1``, then all chains will be simulated for
                ``num_states`` steps. Otherwise, the ``num_chains`` chains
                with the lowest number of states so far (according to
                attribute :attr:`HMC.num_states`) are simulated.
                Alternatively, one may specify a list of chain indices
                (numbers between 0 and :attr:`num_chains`).
            num_parallel (int): How many chains should be simulated in
                parallel. If ``1``, the chains are simulated consecutively
                (one after another).
        """
        if num_chains == -1:
            chain_ids = list(range(self.num_chains))
        elif isinstance(num_chains, (list, tuple)):
            chain_ids = num_chains
            if len(chain_ids) != len(set(chain_ids)):
                warn('Duplicates found in argument "num_chains".')
                chain_ids = list(set(chain_ids))
        else:
            chain_lengths = [(i, o.num_states) \
                             for i, o in enumerate(self.chains)]
            chain_lengths.sort(key=lambda tup: tup[1], reverse=False)
            num_chains = min(num_chains, self.num_chains)
            # Only consider the `num_chains` shortest chains.
            chain_ids = [t[0] for t in chain_lengths[:num_chains]]

        # Put all chains to be simulated in a queue.
        q = Queue()
        for i in chain_ids:
            q.put(self.chains[i])

        def worker():
            # Code will be executed in a separate thread.
            while not q.empty():
                hmc_obj = q.get()
                hmc_obj.simulate_chain(num_states)
                q.task_done()

        # Create threads that simulate the chains in parallel.
        threads = []
        for _ in range(num_parallel):
            t = Thread(target=worker)
            t.start()
            threads.append(t)

        # Block until all queue items are marked as done.
        q.join()
        for t in threads:
            t.join()
    def __del__(self):
        # Destructor.
        if self._close_writers:
            for writer in self._writers:
                writer.close()

    @staticmethod
    def _get_chain_specific_logger(chain_id):
        """Create a chain-specific logger instance.

        Args:
            chain_id (int): The chain identifier.

        Returns:
            (logging.Logger)
        """
        stream_formatter = logging.Formatter( \
            fmt='%(asctime)s - %(levelname)s' + ' - chain %d' % chain_id \
                + ' - %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setFormatter(stream_formatter)
        #stream_handler.setLevel(logging.DEBUG)

        logger = logging.getLogger('logger_chain%d' % chain_id)
        logger.setLevel(logging.DEBUG)
        logger.addHandler(stream_handler)

        return logger
class MCMC:
    r"""Implementation of the Metropolis-Hastings algorithm.

    This class implements the basic Metropolis-Hastings algorithm as, for
    instance, outlined `here <https://arxiv.org/abs/1504.01896>`__ (see
    alg. 1).

    The Metropolis-Hastings algorithm is a simple MCMC algorithm. In
    contrast to :class:`HMC`, sampling is slow as positions follow a random
    walk. However, the algorithm does not need access to gradient
    information, which makes it applicable to a wider range of
    applications.

    We use a normal distribution :math:`\mathcal{N}(p, \sigma^2 I)` as
    proposal, where :math:`p` denotes the previous position (sample point).
    Thus, the proposal is symmetric and cancels in the MH step. The
    potential energy is expected to be passed as negative log-probability
    (up to a constant), such that

    .. math::

        \frac{\pi(\tilde{p}_t)}{\pi(p_{t-1})} \propto \
        \exp \big\{ U(p_{t-1}) - U(\tilde{p}_t) \big\}

    Args:
        (....): See docstring of class :class:`HMC`.
        proposal_std (float): The standard deviation :math:`\sigma` of the
            proposal distribution
            :math:`\tilde{p}_t \sim q(p \mid p_{t-1})`.
    """
    def __init__(self, initial_position, pot_energy_func, proposal_std=1.,
                 logger=None, log_interval=100, writer=None, writer_tag=''):
        self._position = initial_position
        if not self._position.requires_grad:
            self._position.requires_grad = True
        self._proposal_std = proposal_std
        self._pot_energy_func = pot_energy_func
        self._logger = logger
        self._log_interval = log_interval
        self._writer = writer
        self._writer_tag = writer_tag

        self._positions = [initial_position]
        self._num_states = 0
        self._accumulated_accept = 0

    @property
    def proposal_std(self):
        r"""The std :math:`\sigma` of the proposal distribution.

        :type: float
        """
        return self._proposal_std

    @proposal_std.setter
    def proposal_std(self, value):
        self._proposal_std = value

    @property
    def current_position(self):
        """The latest position :math:`q(t)` in the chain simulated so far.

        :type: torch.Tensor
        """
        return self._position

    @property
    def num_states(self):
        """The number of states in the chain visited so far.

        The counter will be increased by method :meth:`simulate_chain`.

        :type: int
        """
        return self._num_states

    @property
    def position_trajectory(self):
        """A list containing all position variables (Markov states) visited
        so far.

        New positions will be added by the method :meth:`simulate_chain`.

        To decrease the memory footprint of objects in this class, the
        trajectory can be cleared via method
        :meth:`clear_position_trajectory`.

        :type: list
        """
        return self._positions

    @property
    def acceptance_probability(self):
        """The fraction of states that have been accepted.

        :type: float
        """
        if self.num_states == 0:
            return 1.0
        return self._accumulated_accept / self._num_states
    def clear_position_trajectory(self, n=None):
        """Reset attribute :attr:`position_trajectory`.

        This method will not affect the counter :attr:`num_states`.

        Args:
            n (int, optional): If provided, only the first ``n`` elements
                of :attr:`position_trajectory` are discarded (e.g., the
                burn-in samples).
        """
        if n is not None:
            self._positions = self._positions[n:]
        else:
            self._positions = []
    def simulate_chain(self, n):
        """Simulate the next ``n`` states of the chain.

        The new states will be appended to attribute
        :attr:`position_trajectory`.

        Args:
            n (int): Number of MCMC steps to be executed.
        """
        logger = self._logger
        writer = self._writer

        for _ in range(n):
            curr_q = self.current_position

            # Sample new proposal.
            eps = torch.normal(torch.zeros_like(curr_q), 1)
            q = curr_q + self.proposal_std * eps

            # Evaluate the potential energy at the current and the proposed
            # position.
            u_q_start = self._pot_energy_func(curr_q)
            u_q_proposal = self._pot_energy_func(q)

            # Metropolis update.
            if torch.rand(1).to(q.device) < min(1, torch.exp(u_q_start - \
                                                             u_q_proposal)):
                accept = True
                self._accumulated_accept += 1
                self._positions.append(q)
            else: # Reject.
                accept = False
                self._positions.append(curr_q.clone())

            self._position = self._positions[-1]
            self._num_states += 1

            # Log progress.
            if accept:
                potential = u_q_proposal.detach().cpu().numpy()
            else:
                potential = u_q_start.detach().cpu().numpy()

            if logger is not None and \
                    (self.num_states-1) % self._log_interval == 0:
                logger.debug('MH state %d: Current Pot. Energy: %f - ' \
                             % (self.num_states, potential) + \
                             'Acceptance probability: %.2f%%.' \
                             % (self.acceptance_probability * 100))
            if writer is not None:
                tag = self._writer_tag

                writer.add_scalar('%smh/potential' % tag, potential,
                    global_step=self.num_states,
                    display_name='Potential Energy')
                writer.add_scalar('%smh/accept' % tag,
                    self.acceptance_probability,
                    global_step=self.num_states,
                    display_name='Acceptance Probability')
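Stripped of logging and PyTorch, the Metropolis update in ``simulate_chain`` can be sketched with NumPy. This is an illustrative standalone version (``metropolis_hastings`` is a name invented for the example); the target is an assumed standard normal, whose potential is :math:`U(q) = q^2/2`:

```python
import numpy as np

def metropolis_hastings(pot_energy, q0, proposal_std=1., n=1000, seed=0):
    """Random-walk Metropolis sampler with a symmetric Gaussian proposal."""
    rng = np.random.default_rng(seed)
    q = q0
    samples = []
    accepted = 0
    for _ in range(n):
        # Symmetric Gaussian proposal centered at the current position.
        q_prop = q + proposal_std * rng.standard_normal()
        # Accept with probability min(1, exp(U(q) - U(q_prop))), i.e. the
        # target-density ratio when U is a negative log-probability.
        if rng.random() < min(1., np.exp(pot_energy(q) - pot_energy(q_prop))):
            q = q_prop
            accepted += 1
        samples.append(q)
    return np.array(samples), accepted / n

# Target: standard normal, U(q) = q^2 / 2 (up to an additive constant).
samples, acc_prob = metropolis_hastings(lambda q: 0.5 * q ** 2, q0=0.,
                                        n=5000)
```

With a proposal std of 1 on this target, roughly two thirds of the proposals are accepted, and the sample mean and std approach 0 and 1 as the chain grows.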
if __name__ == '__main__':
    pass