#!/usr/bin/env python3
# Copyright 2019 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title :permuted_mnist.py
# @author :ch
# @contact :henningc@ethz.ch
# @created :04/11/2019
# @version :1.0
# @python_version :3.6.7
"""
Permuted MNIST Dataset
^^^^^^^^^^^^^^^^^^^^^^
The module :mod:`data.special.permuted_mnist` contains a data handler for the
permuted MNIST dataset.
"""
import copy
import numpy as np
from hypnettorch.data.mnist_data import MNISTData
class PermutedMNISTList():
    """A list of permuted MNIST tasks that only uses a single instance of class
    :class:`PermutedMNIST`.

    An instance of this class emulates a Python list that holds objects of
    class :class:`PermutedMNIST`. However, it doesn't actually hold several
    objects, but only one with just the permutation matrix being exchanged
    every time a different element of this list is retrieved. Therefore, **use
    this class with care**!

    - As all list entries are the same PermutedMNIST object, one should
      never work with several list entries at the same time!
      -> **Retrieving a new list entry will modify every previously
      retrieved list entry!**
    - When retrieving a slice, a shallow copy of this object is created
      (i.e., the underlying :class:`PermutedMNIST` does not change) with
      only the desired subgroup of permutations available.

    Why would one use this object? When working with many permuted MNIST tasks,
    then the memory consumption becomes significant if one desires to hold all
    task instances at once in working memory. An object of this class only needs
    to hold the MNIST dataset once in memory. Just the number of permutation
    matrices grows linearly with the number of tasks.

    Caution:
        **You may never use more than one entry of this class at the same
        time**, as all entries share the same underlying data object and
        therewith the same permutation.

    Note:
        The mini-batch generation process is maintained separately for every
        permutation. Thus, the retrieval of mini-batches for different
        permutations does not influence one another.

    Example:
        You should **never** use this list as follows

        .. code-block:: python

            dhandlers = PermutedMNISTList(permutations, '/tmp')
            d0 = dhandlers[0]
            # Zero-th permutation is active ...
            # ...
            d1 = dhandlers[1]
            # First permutation is active for `d0` and `d1`!
            # Important, you may not use `d0` anymore, as this might lead to
            # undesired behavior.

    Example:
        Instead, always work with only one list entry at a time. The following
        usage would be **correct**

        .. code-block:: python

            dhandlers = PermutedMNISTList(permutations, '/tmp')
            d = dhandlers[0]
            # Zero-th permutation is active ...
            # ...
            d = dhandlers[1]
            # First permutation is active for `d` as expected.

    Args:
        (....): See docstring of constructor of class :class:`PermutedMNIST`.
        permutations: A list of permutations (see parameter ``permutation``
            of class :class:`PermutedMNIST` to have a description of valid list
            entries). The length of this list denotes the number of tasks.
        show_perm_change_msg: Whether to print a notification every time the
            data permutation has been exchanged. This should be enabled
            during development such that a proper use of this list is
            ensured. **Note** You may never work with two elements of this
            list at a time.
    """
    def __init__(self, permutations, data_path, use_one_hot=True,
                 validation_size=0, padding=0, trgt_padding=None,
                 show_perm_change_msg=True):
        print('Loading MNIST into memory, that is shared among %d permutation '
              % (len(permutations)) + 'tasks.')
        # The single shared data handler; only its permutation is swapped when
        # a different "list entry" is requested.
        self._data = PermutedMNIST(data_path, use_one_hot=use_one_hot,
            validation_size=validation_size, permutation=None, padding=padding,
            trgt_padding=trgt_padding)

        self._permutations = permutations
        self._show_perm_change_msg = show_perm_change_msg

        # To ensure that we do not disturb the randomness inside each Dataset
        # object, we store the corresponding batch generators internally.
        # In this way, we don't break the randomness used to generate batches
        # (or the order for deterministically retrieved minibatches, such as
        # test batches).
        self._batch_gens_train = [None] * len(permutations)
        self._batch_gens_test = [None] * len(permutations)
        self._batch_gens_val = [None] * len(permutations)

        # Sanity check! Assert that the implementation inside the `Dataset`
        # class hasn't changed.
        assert hasattr(self._data, '_batch_gen_train') and \
            self._data._batch_gen_train is None
        assert hasattr(self._data, '_batch_gen_test') and \
            self._data._batch_gen_test is None
        assert hasattr(self._data, '_batch_gen_val') and \
            self._data._batch_gen_val is None

        # Index of the currently active permutation.
        self._active_perm = -1

    def __len__(self):
        """Number of tasks."""
        return len(self._permutations)

    def __getitem__(self, index):
        """Return the underlying data object with the index'th permutation.

        Args:
            index: Index of task for which data should be returned. May also
                be a slice, in which case a shallow copy of this list holding
                the selected subset of permutations is returned.

        Return:
            The data loader for task ``index``.
        """
        ### User Warning ###
        color_start = '\033[93m'
        color_end = '\033[0m'
        help_msg = 'To disable this message, disable the flag ' + \
            '"show_perm_change_msg" when calling the constructor of class ' + \
            'classifier.permuted_mnist.PermutedMNISTList.'
        ####################

        if isinstance(index, slice):
            # Shallow copy: the underlying `PermutedMNIST` object is shared,
            # only the bookkeeping lists are sliced.
            new_list = copy.copy(self)
            new_list._permutations = self._permutations[index]
            new_list._batch_gens_train = self._batch_gens_train[index]
            new_list._batch_gens_test = self._batch_gens_test[index]
            new_list._batch_gens_val = self._batch_gens_val[index]

            ### User Warning ###
            if self._show_perm_change_msg:
                indices = list(range(*index.indices(len(self))))
                print(color_start + 'classifier.permuted_mnist.' +
                      'PermutedMNISTList: A slice of permutations with ' +
                      'indices %s has been created. ' % indices +
                      'The applied permutation has not changed! ' + color_end +
                      help_msg)
            ####################

            return new_list

        # Accept plain Python ints as well as numpy integer scalars (e.g.,
        # when indexing with elements of a numpy array).
        assert isinstance(index, (int, np.integer))

        # Backup batch generator to preserve random behavior.
        if self._active_perm != -1:
            self._batch_gens_train[self._active_perm] = \
                self._data._batch_gen_train
            self._batch_gens_test[self._active_perm] = \
                self._data._batch_gen_test
            self._batch_gens_val[self._active_perm] = self._data._batch_gen_val

        # Activate the requested permutation and restore its generators.
        self._data.permutation = self._permutations[index]
        self._data._batch_gen_train = self._batch_gens_train[index]
        self._data._batch_gen_test = self._batch_gens_test[index]
        self._data._batch_gen_val = self._batch_gens_val[index]
        self._active_perm = index

        ### User Warning ###
        if self._show_perm_change_msg:
            color_start = '\033[93m'
            color_end = '\033[0m'
            print(color_start + 'classifier.permuted_mnist.PermutedMNISTList:' +
                  ' Data permutation has been changed to %d. ' % index +
                  color_end + help_msg)
        ####################

        return self._data

    def __setitem__(self, key, value):
        """Not implemented."""
        raise NotImplementedError('Not yet implemented!')

    def __delitem__(self, key):
        """Not implemented."""
        raise NotImplementedError('Not yet implemented!')
class PermutedMNIST(MNISTData):
    """An instance of this class shall represent the permuted MNIST dataset,
    which is the same as the MNIST dataset, just that input pixels are shuffled
    by a random matrix.

    Note:
        Image transformations are computed on the fly when transforming batches
        to torch tensors. Hence, this class is only applicable to PyTorch
        applications. Internally, the class stores the unpermuted images.

    Args:
        data_path: Where should the dataset be read from? If not existing,
            the dataset will be downloaded into this folder.
        use_one_hot: Whether the class labels should be represented in a
            one-hot encoding.
        validation_size: The number of validation samples. Validation
            samples will be taking from the training set (the first :math:`n`
            samples).
        permutation: The permutation that should be applied to the dataset.
            If ``None``, no permutation will be applied. We expect a numpy
            permutation of the form
            :code:`np.random.permutation((28+2*padding)**2)`
        padding: The amount of padding that should be applied to images.

            .. note::
                The padding is currently not reflected in the
                `:attr:`data.dataset.Dataset.in_shape` attribute, as the padding
                is only applied to torch tensors. See attribute
                :attr:`torch_in_shape`.
        trgt_padding (int, optional): If provided, ``trgt_padding`` fake classes
            will be added, such that in total the returned dataset has
            ``len(labels) + trgt_padding`` classes. However, all padded classes
            have no input instances. Note, that 1-hot encodings are padded to
            fit the new number of classes.
    """
    def __init__(self, data_path, use_one_hot=True, validation_size=0,
                 permutation=None, padding=0, trgt_padding=None):
        # Note, image data augmentation doesn't make sense for a dataset that
        # can't be viewed as images due to the random permutations.
        super().__init__(data_path, use_one_hot=use_one_hot,
                         validation_size=validation_size,
                         use_torch_augmentation=False)

        self._padding = padding
        self._input_dim = (28+padding*2)**2
        # Use the property setter (rather than assigning `self._permutation`
        # directly), so that the torch input transforms are rebuilt for the
        # given permutation. A direct attribute assignment would leave
        # `self._transform` unset for a permutation passed to the constructor.
        self.permutation = permutation

        if trgt_padding is not None and trgt_padding > 0:
            print('PermutedMNIST targets will be padded with %d zeroes.' \
                  % trgt_padding)
            self._data['num_classes'] += trgt_padding
            if self.is_one_hot:
                self._data['out_shape'] = [self._data['out_shape'][0] + \
                                           trgt_padding]
                # Append all-zero columns so existing 1-hot labels keep their
                # meaning while matching the padded number of classes.
                out_data = self._data['out_data']
                self._data['out_data'] = np.concatenate((out_data,
                    np.zeros((out_data.shape[0], trgt_padding))), axis=1)

    @property
    def permutation(self):
        """The permutation matrix that is applied to input images before they
        are transformed to Torch tensors."""
        return self._permutation

    @permutation.setter
    def permutation(self, value):
        self._permutation = value
        # Rebuild the on-the-fly torch input transforms for the new
        # permutation.
        self._transform = PermutedMNIST.torch_input_transforms(
            padding=self._padding, permutation=value)

    @property
    def torch_in_shape(self):
        """The input shape of images, similar to attribute `in_shape`.

        In contrast to `in_shape`, this attribute reflects the padding that is
        applied when calling
        :meth:`classifier.permuted_mnist.PermutedMNIST.input_to_torch_tensor`.
        """
        return [self.in_shape[0] + 2 * self._padding,
                self.in_shape[1] + 2 * self._padding, self.in_shape[2]]

    def get_identifier(self):
        """Returns the name of the dataset."""
        return 'PermutedMNIST'
if __name__ == '__main__':
    # This module only provides data-handler classes; there is no standalone
    # script behavior when it is executed directly.
    pass