import torchvision
import torch
import numpy as np
import ssl
from clustpy.data._utils import _get_download_dir
"""
Load torchvision datasets
"""
def _load_torch_image_data(data_source: torchvision.datasets.VisionDataset, subset: str, normalize_channels: bool,
                           uses_train_param: bool, downloads_path: str, is_color_channel_last: bool) -> (
        np.ndarray, np.ndarray):
    """
    Helper function to load a data set from the torchvision package.
    The images are returned as a two-dimensional array (one flattened image per row), created out of the
    HWC (height, width, color channels) image representation.

    Parameters
    ----------
    data_source : torchvision.datasets.VisionDataset
        the data source from torchvision.datasets (the class itself, it is instantiated here)
    subset : str
        can be 'all', 'test' or 'train' (case-insensitive). 'all' combines train and test data, train data first
    normalize_channels : bool
        normalize each color-channel of the images
    uses_train_param : bool
        is the test/train parameter called 'train' or 'split' in the data loader. uses_train_param = True corresponds to 'train'
    downloads_path : str
        path to the directory where the data is stored
    is_color_channel_last : bool
        if true, the color channels should be in the last dimension, known as HWC representation. Alternatively the color channel can be at the first position, known as CHW representation.
        Only relevant for color images -> Should be None for grayscale images

    Returns
    -------
    data, labels : (np.ndarray, np.ndarray)
        the data numpy array (float), the labels numpy array (int)

    Raises
    ------
    AssertionError
        if subset is not one of 'all', 'train' or 'test'
    ValueError
        if the loaded data does not have 3, 4 or 5 dimensions
    """
    subset = subset.lower()
    assert subset in ["all", "train",
                      "test"], "subset must match 'all', 'train' or 'test'. Your input {0}".format(subset)
    # NOTE(review): certificate verification is switched off for the download because the original code
    # does so; this is a process-wide setting and a security trade-off that should be confirmed.
    default_ssl = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        root = _get_download_dir(downloads_path)
        # Train split first so that 'all' keeps the order train -> test
        splits = [_load_torch_split(data_source, root, uses_train_param, is_train)
                  for is_train, split_name in ((True, "train"), (False, "test"))
                  if subset in ("all", split_name)]
    finally:
        # Always restore the default context, also if the download or the parsing failed
        ssl._create_default_https_context = default_ssl
    data, labels = splits[0]
    if len(splits) == 2:
        # subset == 'all': append the test data to the train data
        data = torch.cat([data, splits[1][0]], dim=0)
        labels = torch.cat([labels, splits[1][1]], dim=0)
    # Convert data to float and labels to int
    data = data.float()
    labels = labels.int()
    # Check data dimensions
    if data.dim() < 3 or data.dim() > 5:
        raise ValueError(
            "Number of dimensions for torchvision data sets should be 3, 4 or 5. Here dim={0}".format(data.dim()))
    # Channels can be normalized
    if normalize_channels:
        data = _torch_normalize_channels(data, is_color_channel_last)
    # Flatten shape
    data = _torch_flatten_shape(data, is_color_channel_last, normalize_channels)
    # Move data to CPU
    data_cpu = data.detach().cpu().numpy()
    labels_cpu = labels.detach().cpu().numpy()
    return data_cpu, labels_cpu


def _load_torch_split(data_source, root: str, uses_train_param: bool, is_train: bool) -> (
        torch.Tensor, torch.Tensor):
    """
    Load a single split (train or test) of a torchvision data set as torch tensors.

    Parameters
    ----------
    data_source : torchvision.datasets.VisionDataset
        the data source from torchvision.datasets
    root : str
        directory that is passed to the data source as 'root'
    uses_train_param : bool
        True if the data source selects the split via 'train=<bool>', False if it uses 'split=<str>'
    is_train : bool
        True loads the train split, False loads the test split

    Returns
    -------
    data, labels : (torch.Tensor, torch.Tensor)
        the data tensor and the labels tensor of the requested split
    """
    if uses_train_param:
        dataset = data_source(root=root, train=is_train, download=True)
    else:
        dataset = data_source(root=root, split="train" if is_train else "test", download=True)
    data = dataset.data
    # USPS, MNIST, ... use targets; SVHN, STL10, ... use labels
    labels = dataset.targets if hasattr(dataset, "targets") else dataset.labels
    if isinstance(data, np.ndarray):
        # Transform numpy arrays to torch tensors. Needs to be done for eg USPS
        data = torch.from_numpy(data)
    if not isinstance(labels, torch.Tensor):
        # Labels can be a list or a numpy array, independent of the type of the data
        labels = torch.from_numpy(np.array(labels))
    return data, labels
def _torch_normalize_channels(data: torch.Tensor, is_color_channel_last: bool) -> torch.Tensor:
    """
    Normalize the color channels of a torch dataset.
    Grayscale data is standardized with a single mean and standard deviation computed over the whole tensor.
    Color data is standardized separately for each of the three channels.
    Note that color data is returned in the channel-first representation (CHW or CHWD), also if it was passed
    in the channel-last representation.

    Parameters
    ----------
    data : torch.Tensor
        The torch data tensor (floating point) with 3, 4 or 5 dimensions, the first dimension being the samples
    is_color_channel_last : bool
        if true, the color channels should be in the last dimension, known as HWC representation. Alternatively the color channel can be at the first position, known as CHW representation.
        Only relevant for color images -> Should be None for grayscale images

    Returns
    -------
    The normalized data tensor

    Raises
    ------
    AssertionError
        if color data does not consist of exactly three channels
    ValueError
        if the number of dimensions is not 3, 4 or 5, or if a standard deviation is zero
    """
    n_dims = data.dim()
    if n_dims == 3 or (n_dims == 4 and is_color_channel_last is None):
        # grayscale images (2d or 3d) -> a single scalar statistic broadcasts over everything
        data_mean = data.mean()
        data_std = data.std()
    elif n_dims == 4 or n_dims == 5:
        # color images (2d or 3d)
        if is_color_channel_last:
            # Change to CHW / CHWD representation by moving the last axis behind the sample axis
            data = data.permute(0, n_dims - 1, *range(1, n_dims - 1))
        assert data.shape[1] == 3, "Colored image must consist of three channels not {0}".format(data.shape[1])
        # Statistics per channel: reduce over the samples and all spatial axes
        reduce_dims = [0] + list(range(2, n_dims))
        # Shape (1, 3, 1, ..., 1) so that the statistics line up with the channel axis for 4 and 5 dimensions
        stats_shape = [1, 3] + [1] * (n_dims - 2)
        data_mean = data.mean(reduce_dims).reshape(stats_shape)
        data_std = data.std(reduce_dims).reshape(stats_shape)
    else:
        raise ValueError(
            "Number of dimensions for torchvision data sets should be 3, 4 or 5. Here dim={0}".format(n_dims))
    # Same guard as torchvision.transforms.Normalize, which was used here before
    if (data_std == 0).any():
        raise ValueError(
            "std evaluated to zero after conversion to {0}, leading to division by zero.".format(data.dtype))
    # Out-of-place, the input tensor is not modified
    return (data - data_mean) / data_std
def _torch_flatten_shape(data: torch.Tensor, is_color_channel_last: bool, normalize_channels: bool):
    """
    Turn a tensor of images into a tensor with one flat feature vector per image.
    Color images are brought into the channel-last representation (HWC or HWDC) before they are flattened.
    Tensors that do not have 3, 4 or 5 dimensions are returned unchanged.

    Parameters
    ----------
    data : torch.Tensor
        the image tensor, the first dimension being the samples
    is_color_channel_last : bool
        if true, the color channels should be in the last dimension, known as HWC representation. Alternatively the color channel can be at the first position, known as CHW representation.
        Only relevant for color images -> Should be None for grayscale images
    normalize_channels : bool
        normalize each color-channel of the images

    Returns
    -------
    The flatten data vector
    """
    n_dims = data.dim()
    if n_dims not in (3, 4, 5):
        # Nothing to flatten for other shapes
        return data
    if n_dims == 4:
        # A 3d grayscale image is signalled by is_color_channel_last being None
        is_color = is_color_channel_last is not None
        if is_color and (not is_color_channel_last or normalize_channels):
            # Channel-first input (or input that went through the channel normalization) -> back to HWC
            data = data.permute(0, 2, 3, 1)
        assert not is_color or data.shape[3] == 3, "Colored image must consist of three channels not {0}".format(data.shape[3])
    elif n_dims == 5:
        if not is_color_channel_last or normalize_channels:
            # Channel-first input (or input that went through the channel normalization) -> back to HWDC
            data = data.permute(0, 2, 3, 4, 1)
        assert data.shape[4] == 3, "Colored image must consist of three channels not {0}".format(data.shape[4])
    # Number of features per sample = product of all non-sample dimensions
    n_features = data.shape[1:].numel()
    return data.reshape(-1, n_features)
def load_mnist(subset: str = "all", normalize_channels: bool = False, downloads_path: str = None) -> (
        np.ndarray, np.ndarray):
    """
    Load the MNIST data set of handwritten digits (0 to 9).
    MNIST contains 70000 grayscale images of size 28x28, split into 60000 training and 10000 test images.
    N=70000, d=784, k=10.

    Parameters
    ----------
    subset : str
        which part to load: 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
    normalize_channels : bool
        normalize each color-channel of the images (default: False)
    downloads_path : str
        path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)

    Returns
    -------
    data, labels : (np.ndarray, np.ndarray)
        the data numpy array (70000 x 784), the labels numpy array (70000)

    References
    ----------
    https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST
    """
    # Grayscale data set that selects its split via 'train' -> no color channel position
    return _load_torch_image_data(data_source=torchvision.datasets.MNIST, subset=subset,
                                  normalize_channels=normalize_channels, uses_train_param=True,
                                  downloads_path=downloads_path, is_color_channel_last=None)
def load_kmnist(subset: str = "all", normalize_channels: bool = False, downloads_path: str = None) -> (
        np.ndarray, np.ndarray):
    """
    Load the Kuzushiji-MNIST data set of Kanji characters.
    It contains 70000 grayscale images of size 28x28 showing 10 different characters, each representing one
    column of hiragana. The data is split into 60000 training and 10000 test images.
    N=70000, d=784, k=10.

    Parameters
    ----------
    subset : str
        which part to load: 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
    normalize_channels : bool
        normalize each color-channel of the images (default: False)
    downloads_path : str
        path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)

    Returns
    -------
    data, labels : (np.ndarray, np.ndarray)
        the data numpy array (70000 x 784), the labels numpy array (70000)

    References
    ----------
    https://pytorch.org/vision/stable/generated/torchvision.datasets.KMNIST.html#torchvision.datasets.KMNIST
    """
    # Grayscale data set that selects its split via 'train' -> no color channel position
    return _load_torch_image_data(data_source=torchvision.datasets.KMNIST, subset=subset,
                                  normalize_channels=normalize_channels, uses_train_param=True,
                                  downloads_path=downloads_path, is_color_channel_last=None)
def load_fmnist(subset: str = "all", normalize_channels: bool = False, downloads_path: str = None) -> (
        np.ndarray, np.ndarray):
    """
    Load the Fashion-MNIST data set of articles from the Zalando online store.
    It contains 70000 grayscale images of size 28x28, each belonging to one of 10 product groups.
    The data is split into 60000 training and 10000 test images.
    N=70000, d=784, k=10.

    Parameters
    ----------
    subset : str
        which part to load: 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
    normalize_channels : bool
        normalize each color-channel of the images (default: False)
    downloads_path : str
        path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)

    Returns
    -------
    data, labels : (np.ndarray, np.ndarray)
        the data numpy array (70000 x 784), the labels numpy array (70000)

    References
    ----------
    https://pytorch.org/vision/stable/generated/torchvision.datasets.FashionMNIST.html#torchvision.datasets.FashionMNIST
    """
    # Grayscale data set that selects its split via 'train' -> no color channel position
    return _load_torch_image_data(data_source=torchvision.datasets.FashionMNIST, subset=subset,
                                  normalize_channels=normalize_channels, uses_train_param=True,
                                  downloads_path=downloads_path, is_color_channel_last=None)
def load_usps(subset: str = "all", normalize_channels: bool = False, downloads_path: str = None) -> (
        np.ndarray, np.ndarray):
    """
    Load the USPS data set of handwritten digits (0 to 9).
    It contains 9298 grayscale images of size 16x16, split into 7291 training and 2007 test images.
    N=9298, d=256, k=10.

    Parameters
    ----------
    subset : str
        which part to load: 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
    normalize_channels : bool
        normalize each color-channel of the images (default: False)
    downloads_path : str
        path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)

    Returns
    -------
    data, labels : (np.ndarray, np.ndarray)
        the data numpy array (9298 x 256), the labels numpy array (9298)

    References
    ----------
    https://pytorch.org/vision/stable/generated/torchvision.datasets.USPS.html#torchvision.datasets.USPS
    """
    # Grayscale data set that selects its split via 'train' -> no color channel position
    return _load_torch_image_data(data_source=torchvision.datasets.USPS, subset=subset,
                                  normalize_channels=normalize_channels, uses_train_param=True,
                                  downloads_path=downloads_path, is_color_channel_last=None)
def load_cifar10(subset: str = "all", normalize_channels: bool = False, downloads_path: str = None) -> (
        np.ndarray, np.ndarray):
    """
    Load the CIFAR10 data set of color images showing different objects.
    It contains 60000 color images of size 32x32 from the classes airplane, automobile, bird, cat, deer, dog,
    frog, horse, ship and truck. The data is split into 50000 training and 10000 test images.
    N=60000, d=3072, k=10.

    Parameters
    ----------
    subset : str
        which part to load: 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
    normalize_channels : bool
        normalize each color-channel of the images (default: False)
    downloads_path : str
        path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)

    Returns
    -------
    data, labels : (np.ndarray, np.ndarray)
        the data numpy array (60000 x 3072), the labels numpy array (60000)

    References
    ----------
    https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10
    """
    # Color data set that selects its split via 'train'; it is loaded as channel-last (HWC)
    return _load_torch_image_data(data_source=torchvision.datasets.CIFAR10, subset=subset,
                                  normalize_channels=normalize_channels, uses_train_param=True,
                                  downloads_path=downloads_path, is_color_channel_last=True)
def load_svhn(subset: str = "all", normalize_channels: bool = False, downloads_path: str = None) -> (
        np.ndarray, np.ndarray):
    """
    Load the SVHN data set of house numbers (0 to 9).
    It contains 99289 color images of size 32x32, split into 73257 training and 26032 test images.
    N=99289, d=3072, k=10.

    Parameters
    ----------
    subset : str
        which part to load: 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
    normalize_channels : bool
        normalize each color-channel of the images (default: False)
    downloads_path : str
        path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)

    Returns
    -------
    data, labels : (np.ndarray, np.ndarray)
        the data numpy array (99289 x 3072), the labels numpy array (99289)

    References
    ----------
    https://pytorch.org/vision/stable/generated/torchvision.datasets.SVHN.html#torchvision.datasets.SVHN
    """
    # Color data set that selects its split via 'split'; it is loaded as channel-first (CHW)
    return _load_torch_image_data(data_source=torchvision.datasets.SVHN, subset=subset,
                                  normalize_channels=normalize_channels, uses_train_param=False,
                                  downloads_path=downloads_path, is_color_channel_last=False)
def load_stl10(subset: str = "all", normalize_channels: bool = False, downloads_path: str = None) -> (
        np.ndarray, np.ndarray):
    """
    Load the STL10 data set of color images showing different objects.
    It contains 13000 color images of size 96x96 from the classes airplane, bird, car, cat, deer, dog, horse,
    monkey, ship and truck. The data is split into 5000 training and 8000 test images.
    N=13000, d=27648, k=10.

    Parameters
    ----------
    subset : str
        which part to load: 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
    normalize_channels : bool
        normalize each color-channel of the images (default: False)
    downloads_path : str
        path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)

    Returns
    -------
    data, labels : (np.ndarray, np.ndarray)
        the data numpy array (13000 x 27648), the labels numpy array (13000)

    References
    ----------
    https://pytorch.org/vision/stable/generated/torchvision.datasets.STL10.html#torchvision.datasets.STL10
    """
    # Color data set that selects its split via 'split'; it is loaded as channel-first (CHW)
    return _load_torch_image_data(data_source=torchvision.datasets.STL10, subset=subset,
                                  normalize_channels=normalize_channels, uses_train_param=False,
                                  downloads_path=downloads_path, is_color_channel_last=False)