Source code for hypnettorch.data.mnist_data

#!/usr/bin/env python3
# Copyright 2018 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title           :mnist_data.py
# @author          :ch
# @contact         :henningc@ethz.ch
# @created         :08/08/2018
# @version         :1.0
# @python_version  :3.6.6
"""
MNIST Dataset
-------------

The module :mod:`data.mnist_data` contains a handler for the MNIST dataset.

The implementation is based on an earlier implementation of a class I used in
another project:

    https://git.io/fNyQL

Information about the dataset can be retrieved from:

    http://yann.lecun.com/exdb/mnist/
"""

import os
import struct
import numpy as np
import time
import pickle
import urllib.request
import gzip
import matplotlib.pyplot as plt
from warnings import warn

from hypnettorch.data.cifar10_data import CIFAR10Data
from hypnettorch.data.dataset import Dataset
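
# The helper below is an illustrative sketch only (it is not part of the
# original module and is never called by the class): it decodes a synthetic
# IDX label-file header of the kind parsed by `MNISTData._read_labels` and
# `MNISTData._read_images` further down, i.e., big-endian 4-byte header
# fields followed by one byte per label.
def _idx_header_example():
    # Magic number 2049 and label count 3, followed by the labels 7, 2, 1.
    raw = (2049).to_bytes(4, 'big') + (3).to_bytes(4, 'big') + \
        bytes([7, 2, 1])
    num = int.from_bytes(raw[4:8], byteorder='big')
    labels = np.array(struct.unpack('%dB' % num, raw[8:8+num]))
    return labels  # -> array([7, 2, 1])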

class MNISTData(Dataset):
    """An instance of the class shall represent the MNIST dataset.

    The constructor checks whether the dataset has been read before (a pickle
    dump has been generated). If so, it reads the dump. Otherwise, it reads
    the data from scratch and creates a dump for future usage.

    Note:
        By default, input samples are provided in a range of ``[0, 1]``.

    Args:
        data_path (str): Where should the dataset be read from? If not
            existing, the dataset will be downloaded into this folder.
        use_one_hot (bool): Whether the class labels should be represented in
            a one-hot encoding.
        validation_size (int): The number of validation samples. Validation
            samples will be taken from the training set (the first
            :math:`n` samples).
        use_torch_augmentation (bool): Apply data augmentation to inputs when
            calling method :meth:`data.dataset.Dataset.input_to_torch_tensor`.

            The augmentation will whiten the inputs according to training
            image statistics (mean: ``0.1307``, std: ``0.3081``). In training
            mode, it will additionally apply random crops.

            Note:
                If activated, the statistics of test samples are changed as
                a normalization is applied.
    """
    _DOWNLOAD_PATH = 'http://yann.lecun.com/exdb/mnist/'
    _TRAIN_IMGS_FN = 'train-images-idx3-ubyte.gz'
    _TRAIN_LBLS_FN = 'train-labels-idx1-ubyte.gz'
    _TEST_IMGS_FN = 't10k-images-idx3-ubyte.gz'
    _TEST_LBLS_FN = 't10k-labels-idx1-ubyte.gz'

    # In which file do we dump the dataset, to allow a faster readout next
    # time?
    _MNIST_DATA_DUMP = 'mnist_dataset.pickle'

    # In which subfolder of the data path should the data be stored?
    _SUBFOLDER = 'MNIST'

    def __init__(self, data_path, use_one_hot=False, validation_size=5000,
                 use_torch_augmentation=False):
        super().__init__()

        start = time.time()

        print('Reading MNIST dataset ...')

        # Actual data path.
        data_path = os.path.join(data_path, MNISTData._SUBFOLDER)
        if not os.path.exists(data_path):
            print('Creating directory "%s" ...' % data_path)
            os.makedirs(data_path)

        # If data has been processed before.
        build_from_scratch = True
        dump_fn = os.path.join(data_path, MNISTData._MNIST_DATA_DUMP)
        if os.path.isfile(dump_fn):
            build_from_scratch = False

            with open(dump_fn, 'rb') as f:
                self._data = pickle.load(f)

            # Convert the stored labels if their encoding does not match the
            # requested one.
            if self._data['is_one_hot'] != use_one_hot:
                reverse = True
                if use_one_hot:
                    reverse = False

                self._data['is_one_hot'] = use_one_hot
                self._data['out_data'] = self._to_one_hot(
                    self._data['out_data'], reverse=reverse)
                self._data['out_shape'] = [self._data['out_data'].shape[1]]

            # DELETEME A previous version of the dataloader stored the
            # validation set in the pickle file. Hence, this line ensures
            # downwards compatibility.
            if self.num_val_samples != 0:
                build_from_scratch = True

            self._data['val_inds'] = None

        if build_from_scratch:
            train_images_fn = os.path.join(data_path,
                                           MNISTData._TRAIN_IMGS_FN)
            train_labels_fn = os.path.join(data_path,
                                           MNISTData._TRAIN_LBLS_FN)
            test_images_fn = os.path.join(data_path, MNISTData._TEST_IMGS_FN)
            test_labels_fn = os.path.join(data_path, MNISTData._TEST_LBLS_FN)

            if not os.path.exists(train_images_fn):
                print('Downloading training images ...')
                urllib.request.urlretrieve(MNISTData._DOWNLOAD_PATH +
                                           MNISTData._TRAIN_IMGS_FN,
                                           train_images_fn)

                ## Extract downloaded images.
                #with gzip.open(train_images_fn, 'rb') as f_in:
                #    with open(os.path.splitext(train_images_fn)[0], \
                #              'wb') as f_out:
                #        shutil.copyfileobj(f_in, f_out)

            if not os.path.exists(train_labels_fn):
                print('Downloading training labels ...')
                urllib.request.urlretrieve(MNISTData._DOWNLOAD_PATH +
                                           MNISTData._TRAIN_LBLS_FN,
                                           train_labels_fn)

            if not os.path.exists(test_images_fn):
                print('Downloading test images ...')
                urllib.request.urlretrieve(MNISTData._DOWNLOAD_PATH +
                                           MNISTData._TEST_IMGS_FN,
                                           test_images_fn)

            if not os.path.exists(test_labels_fn):
                print('Downloading test labels ...')
                urllib.request.urlretrieve(MNISTData._DOWNLOAD_PATH +
                                           MNISTData._TEST_LBLS_FN,
                                           test_labels_fn)

            # Read labels.
            train_labels = MNISTData._read_labels(train_labels_fn)
            test_labels = MNISTData._read_labels(test_labels_fn)

            # Read images.
            train_inputs = MNISTData._read_images(train_images_fn)
            test_inputs = MNISTData._read_images(test_images_fn)

            assert train_labels.shape[0] == train_inputs.shape[0]
            assert test_labels.shape[0] == test_inputs.shape[0]

            # Note, we ignore a possible validation set here on purpose, as
            # it should not be part of the pickle (see below).
            train_inds = np.arange(train_labels.size)
            test_inds = np.arange(train_labels.size,
                                  train_labels.size + test_labels.size)

            labels = np.concatenate([train_labels, test_labels])
            images = np.concatenate([train_inputs, test_inputs], axis=0)

            labels = np.reshape(labels, (-1, 1))

            # Scale images into a range between 0 and 1.
            images = images / 255

            # Bring these raw readings into the internal structure of the
            # Dataset class.
            self._data['classification'] = True
            self._data['sequence'] = False
            self._data['num_classes'] = 10
            self._data['is_one_hot'] = use_one_hot
            self._data['in_data'] = images
            self._data['in_shape'] = [28, 28, 1]
            self._data['out_shape'] = [10 if use_one_hot else 1]
            self._data['train_inds'] = train_inds
            self._data['test_inds'] = test_inds

            if use_one_hot:
                labels = self._to_one_hot(labels)

            self._data['out_data'] = labels

            # Save read dataset to allow faster reading in future.
            with open(dump_fn, 'wb') as f:
                pickle.dump(self._data, f)

        # After writing the pickle, correct train and validation set indices.
        if validation_size > 0:
            train_inds_orig = self._data['train_inds']
            if validation_size >= train_inds_orig.size:
                raise ValueError('Validation set must contain fewer than ' +
                                 '%d samples!' % train_inds_orig.size)
            val_inds = np.arange(validation_size)
            train_inds = np.arange(validation_size, train_inds_orig.size)

            self._data['train_inds'] = train_inds
            self._data['val_inds'] = val_inds

        # Initialize PyTorch data augmentation.
        self._augment_inputs = False
        if use_torch_augmentation:
            self._augment_inputs = True
            self._train_transform, self._test_transform = \
                MNISTData.torch_input_transforms(use_random_hflips=False)

        end = time.time()
        print('Elapsed time to read dataset: %f sec' % (end - start))

    @staticmethod
    def _read_labels(filename):
        """Read a set of labels from a file.

        Args:
            filename: Path and name of the byte file that contains the
                labels.

        Returns:
            The labels as a 1D numpy array.
        """
        assert os.path.isfile(filename)

        print('Reading labels from %s.' % filename)

        with gzip.open(filename, 'rb') as f:
            # Skip magic number.
            f.read(4)
            # Get number of labels in this file.
            num = int.from_bytes(f.read(4), byteorder='big')
            print('Number of labels in current file: %d' % num)

            # The rest of the file are "num" bytes, each byte encoding a
            # label.
            labels = np.array(struct.unpack('%dB' % num, f.read(num)))

        return labels

    @staticmethod
    def _read_images(filename):
        """Read a set of images from a file.

        Args:
            filename: Path and name of the byte file that contains the
                images.

        Returns:
            The images stacked in a 2D array, where each row is one image.
        """
        assert os.path.isfile(filename)

        print('Reading images from %s.' % filename)

        with gzip.open(filename, 'rb') as f:
            # Skip magic number.
            f.read(4)
            # Get number of images in this file.
            num = int.from_bytes(f.read(4), byteorder='big')
            print('Number of images in current file: %d' % num)
            # Get number of rows and columns.
            rows = int.from_bytes(f.read(4), byteorder='big')
            cols = int.from_bytes(f.read(4), byteorder='big')

            # The rest of the file consists of pure image data, each pixel
            # value encoded as a byte.
            num_rem_bytes = num * rows * cols
            images = np.array(struct.unpack('%dB' % num_rem_bytes,
                                            f.read(num_rem_bytes)))
            images = np.reshape(images, (-1, rows * cols))

        return images
    @staticmethod
    def plot_sample(image, label=None, interactive=False, file_name=None):
        """Plot a single MNIST sample.

        This method is intended to be helpful for evaluation and debugging
        purposes.

        .. deprecated:: 1.0
            Please use method :meth:`data.dataset.Dataset.plot_samples`
            instead.

        Args:
            image: A single MNIST image (given as 1D vector).
            label: The label of the given image.
            interactive: Turn on interactive mode. If enabled, the program
                will run in the background while the figure is displayed.
                The figure will be displayed until another one is displayed,
                the user closes it or the program has terminated. If this
                option is deactivated, the program will freeze until the
                user closes the figure.
            file_name: (optional) If a file name is provided, then the image
                will be written into a file instead of plotted to the screen.
        """
        warn('Please use method "plot_samples" instead.', DeprecationWarning)

        if label is None:
            plt.title('MNIST Sample')
        else:
            plt.title('Label of shown sample: %d' % label)
        plt.axis('off')
        if interactive:
            plt.ion()
        plt.imshow(np.reshape(image, (28, 28)))
        if file_name is not None:
            plt.savefig(file_name, bbox_inches='tight')
        else:
            plt.show()
    def get_identifier(self):
        """Returns the name of the dataset."""
        return 'MNIST'
    def input_to_torch_tensor(self, x, device, mode='inference',
                              force_no_preprocessing=False, sample_ids=None):
        """This method can be used to map the internal numpy arrays to
        PyTorch tensors.

        Note, this method has been overridden from the base class.

        If enabled via constructor option ``use_torch_augmentation``, input
        images are preprocessed. Preprocessing involves normalization and
        (for training mode) random perturbations.

        Args:
            (....): See docstring of method
                :meth:`data.dataset.Dataset.input_to_torch_tensor`.

        Returns:
            (torch.Tensor): The given input ``x`` as PyTorch tensor.
        """
        if self._augment_inputs and not force_no_preprocessing:
            if mode == 'inference':
                transform = self._test_transform
            elif mode == 'train':
                transform = self._train_transform
            else:
                raise ValueError('"%s" not a valid value for argument '
                                 '"mode".' % mode)

            return CIFAR10Data.torch_augment_images(x, device, transform,
                img_shape=self.in_shape)

        else:
            return Dataset.input_to_torch_tensor(self, x, device, mode=mode,
                force_no_preprocessing=force_no_preprocessing,
                sample_ids=sample_ids)
    def _plot_sample(self, fig, inner_grid, num_inner_plots, ind, inputs,
                     outputs=None, predictions=None):
        """Implementation of abstract method
        :meth:`data.dataset.Dataset._plot_sample`.
        """
        ax = plt.Subplot(fig, inner_grid[0])

        if outputs is None:
            ax.set_title('MNIST Sample')
        else:
            assert np.size(outputs) == 1
            # `item()` extracts the Python scalar (`np.asscalar` was removed
            # in recent NumPy versions).
            label = np.asarray(outputs).item()

            if predictions is None:
                ax.set_title('MNIST sample with\nlabel: %d' % label)
            else:
                if np.size(predictions) == self.num_classes:
                    pred_label = np.argmax(predictions)
                else:
                    pred_label = np.asarray(predictions).item()

                ax.set_title('MNIST sample with\nlabel: %d (prediction: %d)'
                             % (label, pred_label))

        #plt.subplots_adjust(wspace=0.5, hspace=0.4)

        ax.set_axis_off()
        ax.imshow(np.squeeze(np.reshape(inputs, self.in_shape)))
        fig.add_subplot(ax)

        if num_inner_plots == 2:
            ax = plt.Subplot(fig, inner_grid[1])
            ax.set_title('Predictions')
            bars = ax.bar(range(self.num_classes), np.squeeze(predictions))
            ax.set_xticks(range(self.num_classes))
            if outputs is not None:
                bars[int(label)].set_color('r')
            fig.add_subplot(ax)

    def _plot_config(self, inputs, outputs=None, predictions=None):
        """Re-implementation of method
        :meth:`data.dataset.Dataset._plot_config`.

        This method has been overridden to ensure that there are two
        subplots in case predictions are given.
        """
        plot_configs = super()._plot_config(inputs, outputs=outputs,
                                            predictions=predictions)

        if predictions is not None and \
                np.shape(predictions)[1] == self.num_classes:
            plot_configs['outer_hspace'] = 0.6
            plot_configs['inner_hspace'] = 0.4
            plot_configs['num_inner_rows'] = 2
            #plot_configs['num_inner_cols'] = 1
            plot_configs['num_inner_plots'] = 2

        return plot_configs
    @staticmethod
    def torch_input_transforms(use_random_hflips=False):
        """Get data augmentation pipelines for MNIST inputs.

        Args:
            use_random_hflips (bool): Also use random horizontal flips during
                training.

                Note:
                    This should not be ``True`` for MNIST, since digits lose
                    their meaning when flipped.

        Returns:
            (tuple): Tuple containing:

            - **train_transform**: A transforms pipeline that applies random
              transformations and normalizes the image.
            - **test_transform**: Similar to ``train_transform``, but no
              random transformations are applied.
        """
        import torchvision.transforms as transforms

        # Training image statistics of MNIST.
        normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))

        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop(size=[28, 28], padding=4)] +
            ([transforms.RandomHorizontalFlip()] if use_random_hflips
             else []) +
            [transforms.ToTensor(),
             normalize
        ])

        test_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            normalize,
        ])

        return train_transform, test_transform
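
# The sketch below is illustrative only (it is not part of the original
# module): it applies the two pipelines returned by `torch_input_transforms`
# to a dummy uint8 image that stands in for a raw MNIST sample.
def _transform_example():
    train_tf, test_tf = MNISTData.torch_input_transforms()
    img = np.random.randint(0, 256, size=(28, 28, 1), dtype=np.uint8)
    # `ToPILImage` expects an HxWxC uint8 array; both pipelines end with the
    # MNIST normalization (mean 0.1307, std 0.3081) and return a [1, 28, 28]
    # float tensor. The training pipeline additionally applies a random crop.
    return train_tf(img), test_tf(img)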
if __name__ == '__main__':
    pass
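
# Minimal usage sketch (illustrative only, not part of the original module;
# never executed on import). The data path './datasets' is an assumption;
# on first use the handler downloads the raw files into './datasets/MNIST'
# and pickles the parsed dataset for faster subsequent loading.
def _usage_example():
    mnist = MNISTData('./datasets', use_one_hot=True, validation_size=5000)
    print('Identifier: %s' % mnist.get_identifier())
    print('Input shape: %s' % str(mnist.in_shape))
    print('Number of classes: %d' % mnist.num_classes)
    print('Validation samples: %d' % mnist.num_val_samples)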