# Source code for hypnettorch.data.special.donuts

#!/usr/bin/env python3
# Copyright 2021 Francesco D'Angelo
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title          :data/special/donuts.py
# @author         :fd, ch
# @contact        :fdangelo@student.ethz.ch
# @created        :07/08/2021
# @version        :1.0
# @python_version :3.8.10
"""
2D Donut Dataset
^^^^^^^^^^^^^^^^

This data handler creates a synthetic toy problem comprising 2D annuli.
"""
import numpy as np
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt

from hypnettorch.data.dataset import Dataset

class Donuts(Dataset):
    """Donut dataset handler.

    Each donut (2D annulus) constitutes its own class, i.e., samples drawn
    from the ``i``-th donut are labeled with class ``i``.

    Args:
        centers (tuple or list): List of tuples, each determining the center
            of a donut.
        radii (tuple or list): List of tuples, each tuple defines the inner
            and outer radius of a donut. Must have the same length as
            ``centers``.
        num_train (int): Number of training samples per donut.
        num_test (int): Number of test samples per donut.
        use_one_hot (bool): Whether the class labels should be represented as
            a one-hot encoding.
        rseed (int): If ``None``, the current random state of numpy is used to
            generate the data. Otherwise, a new random state with the given
            seed is generated.
    """
    def __init__(self, centers=((0, 0), (0, 0)), radii=((3, 4), (9, 10)),
                 num_train=100, num_test=100, use_one_hot=True, rseed=42):
        super().__init__()

        if rseed is None:
            rand = np.random
        else:
            rand = np.random.RandomState(rseed)

        assert len(centers) == len(radii), \
            'Each donut requires one center and one pair of radii.'
        assert len(centers) > 0, 'At least one donut has to be specified.'
        num_classes = len(centers)

        train_x, train_y, test_x, test_y = [], [], [], []
        for i, (c, r) in enumerate(zip(centers, radii)):
            # The sampling order (per donut: train first, then test) decides
            # which random numbers end up in which split; keep it fixed so a
            # given ``rseed`` always reproduces the same dataset.
            train_x.append(Donuts.sample_annulus(c[0], c[1], r[0], r[1],
                                                 num=num_train, rand=rand))
            test_x.append(Donuts.sample_annulus(c[0], c[1], r[0], r[1],
                                                num=num_test, rand=rand))
            # Class label ``i`` for every sample of donut ``i`` (float column).
            train_y.append(np.full((num_train, 1), i, dtype=float))
            test_y.append(np.full((num_test, 1), i, dtype=float))

        # Stack once at the end instead of growing arrays inside the loop.
        train_x = np.vstack(train_x)
        train_y = np.vstack(train_y)
        test_x = np.vstack(test_x)
        test_y = np.vstack(test_y)

        num_train_total = train_x.shape[0]
        num_test_total = test_x.shape[0]

        # Training samples come first, followed by all test samples.
        in_data = np.vstack([train_x, test_x])
        out_data = np.vstack([train_y, test_y])

        # Specify internal data structure.
        self._data['classification'] = True
        self._data['sequence'] = False
        self._data['in_data'] = in_data
        self._data['in_shape'] = [2]
        self._data['num_classes'] = num_classes
        # Record the label encoding, so that consumers (e.g.,
        # `plot_dataset`) know how to interpret `out_data`.
        self._data['is_one_hot'] = use_one_hot
        if use_one_hot:
            # NOTE(review): presumably `_to_one_hot` relies on
            # `classification`/`num_classes` being set above; keep the order.
            out_data = self._to_one_hot(out_data)
        self._data['out_data'] = out_data
        # One column per class if one-hot encoded, a single label otherwise.
        self._data['out_shape'] = [num_classes if use_one_hot else 1]
        self._data['train_inds'] = np.arange(num_train_total)
        self._data['test_inds'] = np.arange(num_train_total,
                                            num_train_total + num_test_total)

    def get_identifier(self):
        """Returns the name of the dataset."""
        return 'DonutsDataset'

    @staticmethod
    def sample_annulus(x_c, y_c, r_inner, r_outer, num=1, rand=None):
        r"""Sample uniformly from an annulus.

        Sample points :math:`(x, y)` uniformly (w.r.t. area) such that

        .. math::

            r_\text{inner}^2 \leq (x-x_\text{c})^2 + (y-y_\text{c})^2

        and

        .. math::

            (x-x_\text{c})^2 + (y-y_\text{c})^2 < r_\text{outer}^2

        (up to floating-point precision).

        Args:
            x_c (float): x-position of the center.
            y_c (float): y-position of the center.
            r_inner (float): Inner radius (non-negative).
            r_outer (float): Outer radius (not smaller than ``r_inner``).
            num (int): Number of samples to return.
            rand (numpy.random.RandomState, optional): Random state object
                used for sampling. If ``None``, numpy's global random state is
                used.

        Returns:
            (numpy.ndarray): Array of shape ``[num, 2]``.
        """
        # The code is inspired by this thread:
        # https://stackoverflow.com/questions/47005884/random-point-inside-annulus-with-a-shifted-hole
        assert 0 <= r_inner <= r_outer, \
            'Radii must satisfy 0 <= r_inner <= r_outer.'

        if rand is None:
            rand = np.random

        # Drawing the *squared* radius uniformly yields a density that is
        # uniform w.r.t. area, as the enclosed area grows with r^2.
        rad = np.sqrt(rand.uniform(low=r_inner ** 2, high=r_outer ** 2,
                                   size=num))
        angle = rand.uniform(low=-np.pi, high=np.pi, size=num)
        x = rad * np.cos(angle) + x_c
        y = rad * np.sin(angle) + y_c
        return np.column_stack((x, y))

    def _plot_sample(self, fig, inner_grid, num_inner_plots, ind, inputs,
                     outputs=None, predictions=None):
        """Not implemented"""
        raise NotImplementedError('TODO implement')

    def plot_dataset(self, title, show=True, filename=None, interactive=False,
                     figsize=(10, 6)):
        """Plot all training and test samples of this dataset.

        Training samples are drawn as circles, test samples as slightly
        transparent squares. The color encodes the donut (class) a sample
        belongs to.

        Args:
            title (str): Title of the figure.
            show (bool): Whether :func:`matplotlib.pyplot.show` is called.
            filename (str, optional): If provided, the figure is stored under
                this filename.
            interactive (bool): If ``True``, matplotlib's interactive mode is
                switched on before plotting.
            figsize (tuple): Size of the figure in inches.
        """
        num_classes = self._data['num_classes']

        # Red and blue for the first two donuts; further colors only matter
        # if more donuts exist (they repeat beyond the palette's length).
        palette = ['#FF0000', '#0000FF', '#00A000', '#FFA500', '#800080',
                   '#8B4513', '#FF1493', '#808080', '#808000', '#00CED1']
        colors = ListedColormap([palette[k % len(palette)]
                                 for k in range(num_classes)])

        def to_labels(outputs):
            # Convert stored targets into integer class labels.
            if self._data['is_one_hot']:
                return np.argmax(outputs, axis=1)
            return outputs[:, 0].astype(int)

        if interactive:
            plt.ion()

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)

        # Fixed vmin/vmax so that class k always maps to the k-th color,
        # independent of which classes appear in a particular scatter call.
        common = dict(cmap=colors, vmin=0, vmax=num_classes - 1,
                      edgecolors='k', s=50)

        x_train = self.get_train_inputs()
        ax.scatter(x_train[:, 0], x_train[:, 1], alpha=1, marker='o',
                   c=to_labels(self.get_train_outputs()), label='train',
                   **common)
        x_test = self.get_test_inputs()
        ax.scatter(x_test[:, 0], x_test[:, 1], alpha=0.6, marker='s',
                   c=to_labels(self.get_test_outputs()), label='test',
                   **common)

        ax.set_title(title, fontsize=30)
        ax.legend(loc=2, fontsize=30)

        # Save before showing, as a blocking `show` may destroy the figure.
        if filename is not None:
            fig.savefig(filename, bbox_inches='tight')
        if show:
            plt.show()
if __name__ == '__main__':
    # This module only defines a data handler; running it as a script is a
    # no-op.
    pass