#!/usr/bin/env python3
# Copyright 2019 Christian Henning
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# @title :utils/misc.py
# @author :ch
# @contact :henningc@ethz.ch
# @created :04/11/2019
# @version :1.0
# @python_version :3.6.7
"""
Miscellaneous Utilities
-----------------------
A collection of helper functions.
"""
import inspect
import matplotlib
import matplotlib.pyplot as plt
import math
from torch import nn
import torch
from warnings import warn
def init_params(weights, bias=None):
    """Initialize the weights and biases of a linear or (transpose) conv layer.

    Note, the implementation is based on the method "reset_parameters()",
    that defines the original PyTorch initialization for a linear or
    convolutional layer, resp. The implementations can be found here:

        https://git.io/fhnxV

        https://git.io/fhnx2

    .. deprecated:: 1.0
        Please use function :func:`utils.torch_utils.init_params` instead.

    Args:
        weights: The weight tensor to be initialized.
        bias (optional): The bias tensor to be initialized. If the fan-in of
            ``weights`` is zero, the bias is initialized to zero.
    """
    # ``stacklevel=2`` attributes the warning to the caller of this deprecated
    # helper rather than to this line.
    warn('Function is deprecated. Use "utils.torch_utils.init_params" instead.',
         DeprecationWarning, stacklevel=2)
    nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
    if bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
        # A zero fan-in (layer without input features) would otherwise cause a
        # division by zero; PyTorch's own ``reset_parameters()`` uses the same
        # guard and falls back to a bound of zero.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(bias, -bound, bound)
def str_to_ints(str_arg):
    """Parse a comma-separated string of integers into a list of integers.

    Args:
        str_arg: String containing list of comma-separated ints. For
            convenience reasons, the user may also pass a single integer,
            which is then put into a list of length 1 by this function.

    Returns:
        (list): List of integers. An empty string yields an empty list.
    """
    # Guard clause: a bare integer is simply wrapped.
    if isinstance(str_arg, int):
        return [str_arg]
    # Nothing to parse.
    if len(str_arg) == 0:
        return []
    tokens = str_arg.split(',')
    return list(map(int, tokens))
def str_to_floats(str_arg):
    """Helper function to convert a string which is a list of comma separated
    floats into an actual list of floats.

    Args:
        str_arg: String containing list of comma-separated floats. For
            convenience reasons, we allow the user to also pass a single
            number (float or int) that is put into a list of length 1 by this
            function. A single int is converted to float.

    Returns:
        (list): List of floats. An empty string yields an empty list.
    """
    if isinstance(str_arg, float):
        return [str_arg]
    # Also accept a bare integer (e.g., a command-line default such as ``1``),
    # which would otherwise reach ``len()`` below and raise a ``TypeError``.
    # ``bool`` is deliberately excluded even though it subclasses ``int``.
    if isinstance(str_arg, int) and not isinstance(str_arg, bool):
        return [float(str_arg)]
    if len(str_arg) > 0:
        return [float(s) for s in str_arg.split(',')]
    return []
def list_to_str(list_arg, delim=' '):
    """Convert a list of numbers into a string.

    Args:
        list_arg: List of numbers. Any iterable works; each element is
            converted via ``str``.
        delim (optional): Delimiter between numbers.

    Returns:
        (str): List converted to string.
    """
    # ``str.join`` builds the result in a single pass instead of repeatedly
    # concatenating (and copying) a growing string.
    return delim.join(str(e) for e in list_arg)
def str_to_act(act_str):
    """Convert the name of an activation function into the actual PyTorch
    activation function.

    Args:
        act_str: Name of activation function (as defined by command-line
            arguments). One of ``linear``, ``sigmoid``, ``relu``, ``elu`` or
            ``tanh``.

    Returns:
        Torch activation function instance or ``None``, if ``linear`` is given.

    Raises:
        ValueError: If ``act_str`` is not a known activation function.
    """
    if act_str == 'linear':
        return None

    # Map names to classes (not instances) so that every call returns a fresh
    # module instance.
    activations = {
        'sigmoid': torch.nn.Sigmoid,
        'relu': torch.nn.ReLU,
        'elu': torch.nn.ELU,
        'tanh': torch.nn.Tanh,
    }
    try:
        act_cls = activations[act_str]
    except (KeyError, TypeError):
        # ``TypeError`` covers unhashable inputs, which are unknown names too.
        # ``ValueError`` subclasses ``Exception``, so existing callers that
        # catch ``Exception`` keep working.
        raise ValueError('Activation function %s unknown.' % (act_str,)) \
            from None
    return act_cls()
# Qualitative color palettes from colorbrewer2.org, keyed by the family names
# accepted by ``get_colorbrewer2_colors``. The values are kept exactly as they
# have always been returned, for backward compatibility.
_COLORBREWER2_PALETTES = {
    # NOTE Despite its key, this palette is the 7-color "Set1" scheme
    # (https://colorbrewer2.org/#type=qualitative&scheme=Set1&n=7) extended by
    # '#b3de69', which also appears in the "Set3" palette below.
    'Set2': (
        '#e41a1c',
        '#377eb8',
        '#4daf4a',
        '#984ea3',
        '#ff7f00',
        '#ffff33',
        '#a65628',
        '#b3de69',
    ),
    # https://colorbrewer2.org/#type=qualitative&scheme=Set3&n=8
    'Set3': (
        '#8dd3c7',
        '#ffffb3',
        '#bebada',
        '#fb8072',
        '#80b1d3',
        '#fdb462',
        '#b3de69',
        '#fccde5',
    ),
    # https://colorbrewer2.org/#type=qualitative&scheme=Dark2&n=8
    'Dark2': (
        '#1b9e77',
        '#d95f02',
        '#7570b3',
        '#e7298a',
        '#66a61e',
        '#e6ab02',
        '#a6761d',
        '#666666',
    ),
    # 7-color "Pastel1" scheme:
    # https://colorbrewer2.org/#type=qualitative&scheme=Pastel1&n=7
    'Pastel': (
        '#fbb4ae',
        '#b3cde3',
        '#ccebc5',
        '#decbe4',
        '#fed9a6',
        '#ffffcc',
        '#e5d8bd',
    ),
}


def get_colorbrewer2_colors(family='Set2'):
    """Helper function that returns a list of color combinations
    extracted from colorbrewer2.org.

    Args:
        family (str): The color family to use. One of ``'Set2'``, ``'Set3'``,
            ``'Dark2'`` or ``'Pastel'``. Note that ``'Set2'`` (the default)
            actually yields a Set1-based palette.

    Returns:
        (list): List of hex color strings, e.g., ``'#e41a1c'``.

    Raises:
        ValueError: If ``family`` is not a known color family.
    """
    try:
        palette = _COLORBREWER2_PALETTES[family]
    except (KeyError, TypeError):
        # ``TypeError`` covers unhashable inputs, which are unknown too.
        raise ValueError('Unknown color family %r. Valid families are: %s.'
                         % (family, ', '.join(_COLORBREWER2_PALETTES))) \
            from None
    # Return a fresh list so that callers may modify it without affecting
    # subsequent calls.
    return list(palette)
def repair_canvas_and_show_fig(fig, close=True):
    """Give ``fig`` a fresh canvas and show it.

    Writing a figure to tensorboard via "add_figure" might change its canvas
    such that our backend can no longer display the figure. This function
    borrows the canvas of a newly created helper figure, attaches it to
    ``fig``, discards the helper figure and then shows ``fig``.

    Args:
        fig: The figure to be shown.
        close: Whether the figure should be closed after it has been shown.
    """
    # Create a throw-away figure solely to obtain a canvas from its manager.
    donor = plt.figure()
    fresh_canvas = donor.canvas.manager.canvas
    # Wire the borrowed canvas and the target figure to each other.
    fresh_canvas.figure = fig
    fig.set_canvas(fresh_canvas)
    # The helper figure itself is no longer needed.
    plt.close(donor.number)

    # Make ``fig`` the active figure and display it.
    plt.figure(fig.number)
    plt.show()

    if not close:
        return
    plt.close(fig.number)
def get_default_args(func):
    """Collect the default values of a function's parameters.

    Args:
        func: A function handle.

    Returns:
        (dict): Maps the name of every parameter that declares a default
        value to that default value. Parameters without a default are
        omitted.
    """
    # The code from this function has been copied from (accessed: 02/28/2020):
    # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value
    #
    # NOTE Our copyright and license does not apply for this function.
    # We use this code WITHOUT ANY WARRANTIES.
    #
    # Instead, the code in this method is licensed under CC BY-SA 3.0:
    # https://creativecommons.org/licenses/by-sa/3.0/
    #
    # The code stems from an answer by user "mgilson":
    # https://stackoverflow.com/users/748858/mgilson
    params = inspect.signature(func).parameters
    defaults = {}
    for name, param in params.items():
        # Skip parameters that have no default value.
        if param.default is inspect.Parameter.empty:
            continue
        defaults[name] = param.default
    return defaults
if __name__ == '__main__':
    # This module only provides helper functions; running it directly does
    # nothing.
    pass