Source code for hypnettorch.mnets.classifier_interface

#!/usr/bin/env python3
# Copyright 2019 Christian Henning
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# @title          :mnets/classifier_interface.py
# @author         :ch
# @contact
# @created        :09/20/2019
# @version        :1.0
# @python_version :3.6.8
"""
Interface for Classifiers

A general interface for main networks used in classification tasks. This
abstract base class also provides a collection of static helper functions that
are useful in classification problems.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from warnings import warn

from hypnettorch.mnets.mnet_interface import MainNetInterface

class Classifier(nn.Module, MainNetInterface):
    """A general interface for classification networks.

    Besides storing the number of output classes, this class bundles a set of
    static helper functions (loss functions and an accuracy measure) that are
    useful in classification problems.
    """

    def __init__(self, num_classes, verbose):
        """Initialize the network.

        Args:
            num_classes: The number of output neurons. Must be positive.
            verbose: Allow printing of general information about the generated
                network (such as number of weights).
        """
        # FIXME find a way using super to handle multiple inheritance.
        #super(Classifier, self).__init__()
        # Both base classes are initialized explicitly (see FIXME above).
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        assert num_classes > 0
        self._num_classes = num_classes
        self._verbose = verbose

    @property
    def num_classes(self):
        """Number of output neurons.

        :type: int
        """
        return self._num_classes

    @staticmethod
    def logit_cross_entropy_loss(h, t, reduction='mean'):
        """Compute cross-entropy loss for given predictions and targets.

        Note, we assume that the argmax of the target vectors results in the
        correct label.

        Args:
            h: Unscaled outputs from the main network, i.e., activations of the
                last hidden layer (unscaled logits).
            t: Targets in form of soft labels or 1-hot encodings.
            reduction (str): The reduction method to be passed to
                :func:`torch.nn.functional.cross_entropy`.

        Returns:
            Cross-entropy loss computed on logits ``h`` and labels extracted
            from target vector ``t``.
        """
        assert t.shape[1] == h.shape[1]
        # Hard labels are recovered from the (soft or 1-hot) target vectors.
        targets = t.argmax(dim=1, keepdim=False)
        return F.cross_entropy(h, targets, reduction=reduction)

    @staticmethod
    def knowledge_distillation_loss(logits, target_logits, target_mapping=None,
                                    device=None, T=2.):
        """Compute the knowledge distillation loss as proposed by

            Hinton et al., "Distilling the Knowledge in a Neural Network",
            NIPS Deep Learning and Representation Learning Workshop, 2015.

        Args:
            logits: Unscaled outputs from the main network, i.e., activations
                of the last hidden layer (unscaled logits).
            target_logits: Target logits, i.e., activations of the last hidden
                layer (unscaled logits) from the target model.

                Note, we won't detach "target_logits" from the graph. Make
                sure, that you do this before calling this method.
            target_mapping: In continual learning, it might be that the output
                layer size of a model is growing. Thus, it could be that the
                model providing the ``target_logits`` has a smaller output
                size than the current model providing the ``logits``.
                Therefore, one has to provide a mapping, which is a list of
                indices for ``logits`` that state which activations in
                ``logits`` have a corresponding target in ``target_logits``.

                For instance, if the output layer size just increased by 1
                through appending a new output neuron to the current model,
                the mapping would simply be:
                :code:`target_mapping = list(range(target_logits.shape[1]))`.
            device: Current PyTorch device. Only relevant if
                ``target_mapping`` is given. If ``None``, the padded target
                tensor stays on the device of ``logits``.
            T: Softmax temperature.

        Returns:
            Knowledge Distillation (KD) loss.

        Raises:
            ValueError: If the output sizes of ``logits`` and
                ``target_logits`` differ while no ``target_mapping`` is given.
        """
        targets = F.softmax(target_logits / T, dim=1)
        n_classes = logits.shape[1]
        n_targets = targets.shape[1]

        if target_mapping is None:
            if n_classes != n_targets:
                raise ValueError('If sizes of "logits" and "target_logits" ' +
                                 'differ, "target_mapping" must be specified.')
        else:
            # Outputs without a counterpart in the target model get target
            # probability zero. ``zeros_like`` already allocates on the device
            # of ``logits``, so an explicit device is optional.
            new_targets = torch.zeros_like(logits)
            if device is not None:
                new_targets = new_targets.to(device)
            new_targets[:, target_mapping] = targets
            targets = new_targets

        # Note, I think the multiplication with T^2 here is not necessary. The
        # original paper prescribes it, but on a gradient analysis where the
        # MSE between tempered softmax targets and predictions is minimized
        # (assuming the logits are zero-mean). Here, the gradient should
        # consist of two terms where the first is scaled by 1/T and the second
        # can be considered as scaled by 1/T^2 if making the same assumption
        # as the distillation paper.
        # Though, I wouldn't change any of this, since the loss function has
        # been used and I don't think it matters for reasonable temperature
        # choices.
        log_probs = F.log_softmax(logits / T, dim=1)
        return -(targets * log_probs).sum(dim=1).mean() * T**2

    @staticmethod
    def softmax_and_cross_entropy(h, t, reduction_sum=False):
        """Compute the cross entropy from logits, allowing smoothed labels
        (i.e., this function does not require 1-hot targets).

        Args:
            h: Unscaled outputs from the main network, i.e., activations of the
                last hidden layer (unscaled logits).
            t: Targets in form of soft labels or 1-hot encodings.
            reduction_sum (bool): If ``True``, the per-sample losses are
                summed over the batch; otherwise they are averaged.

        Returns:
            Cross-entropy loss computed on logits ``h`` and given targets
            ``t``.
        """
        assert t.shape[1] == h.shape[1]

        loss = -(t * F.log_softmax(h, dim=1)).sum(dim=1)

        if reduction_sum:
            return loss.sum()
        return loss.mean()

    @staticmethod
    def accuracy(y, t):
        """Computing the accuracy between predictions y and targets t.

        We assume that the argmax of t results in labels as described in the
        docstring of method :meth:`logit_cross_entropy_loss`.

        Args:
            y: Outputs from the main network.
            t: Targets in form of soft labels or 1-hot encodings.

        Returns:
            Relative prediction accuracy on the given batch.
        """
        assert t.shape[1] == y.shape[1]

        predictions = y.argmax(dim=1, keepdim=False)
        targets = t.argmax(dim=1, keepdim=False)

        return (predictions == targets).float().mean()

    @staticmethod
    def num_hyper_weights(dims):
        """The number of weights that have to be predicted by a hypernetwork.

        .. deprecated:: 1.0
            Please use method
            :meth:`mnets.mnet_interface.MainNetInterface.shapes_to_num_weights`
            instead.

        Args:
            dims: For instance, the attribute :attr:`hyper_shapes`.

        Returns:
            (int): Sum over all shapes in ``dims`` of the product of each
            shape's entries.
        """
        # stacklevel=2 attributes the warning to the caller of this method.
        warn('Please use class "mnets.mnet_interface.MainNetInterface.' +
             'shapes_to_num_weights" instead.', DeprecationWarning,
             stacklevel=2)

        # Each shape contributes the product of its dimensions. The cast
        # keeps the documented ``int`` return type even for empty ``dims``.
        return int(np.sum([np.prod(shape) for shape in dims]))
# Running this module as a script is a deliberate no-op.
if __name__ == '__main__': pass