import torchvision
import torch
import numpy as np
import ssl
from clustpy.data._utils import _get_download_dir, _load_color_image_data
"""
Load torchvision datasets
"""
def _get_data_and_labels(dataset: torchvision.datasets.VisionDataset, image_size: tuple) -> (
torch.Tensor, torch.Tensor):
"""
Extract data and labels from a torchvision dataset object.
Parameters
----------
dataset : torchvision.datasets.VisionDataset
The torchvision dataset object
image_size : tuple
for some datasets (e.g., GTSRB) the images of various sizes must be converted into a coherent size.
The tuple equals (width, height) of the images
Returns
-------
data, labels : (torch.Tensor, torch.Tensor)
the data torch tensor, the labels torch tensor
"""
if hasattr(dataset, "data"):
# USPS, MNIST, ... use data parameter
data = dataset.data
if hasattr(dataset, "targets"):
# USPS, MNIST, ... use targets
labels = dataset.targets
else:
# SVHN, STL10, ... use labels
labels = dataset.labels
else:
# GTSRB only gives path to images
labels = []
data_list = []
for path, label in dataset._samples:
labels.append(label)
image_data = _load_color_image_data(path, image_size)
data_list.append(image_data)
# Convert data form list to numpy array
data = np.array(data_list)
labels = np.array(labels)
if type(data) is np.ndarray:
# Transform numpy arrays to torch tensors. Needs to be done for eg USPS
data = torch.from_numpy(data)
labels = torch.from_numpy(np.array(labels))
return data, labels
def _load_torch_image_data(data_source: torchvision.datasets.VisionDataset, subset: str, flatten: bool,
normalize_channels: bool, uses_train_param: bool, downloads_path: str,
is_color_channel_last: bool, image_size: tuple = None) -> (np.ndarray, np.ndarray):
"""
Helper function to load a data set from the torchvision package.
All data sets will be returned as a two-dimensional tensor, created out of the HWC (height, width, color channels) image representation.
Parameters
----------
data_source : torchvision.datasets.VisionDataset
the data source from torchvision.datasets
subset : str
can be 'all', 'test' or 'train'. 'all' combines test and train data
flatten : bool
should the image data be flatten, i.e. should the format be changed to a (N x d) array.
If false, color images will be returned in the CHW format
normalize_channels : bool
normalize each color-channel of the images
uses_train_param : bool
is the test/train parameter called 'train' or 'split' in the data loader. uses_train_param = True corresponds to 'train'
downloads_path : str
path to the directory where the data is stored
is_color_channel_last : bool
if true, the color channels should be in the last dimension, known as HWC representation. Alternatively the color channel can be at the first position, known as CHW representation.
Only relevant for color images -> Should be None for grayscale images
image_size : tuple
for some datasets (e.g., GTSRB) the images of various sizes must be converted into a coherent size.
The tuple equals (width, height) of the images (default: None)
Returns
-------
data, labels : (np.ndarray, np.ndarray)
the data numpy array, the labels numpy array
"""
subset = subset.lower()
assert subset in ["all", "train",
"test"], "subset must match 'all', 'train' or 'test'. Your input {0}".format(subset)
# Get data from source
default_ssl = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
if subset == "all" or subset == "train":
# Load training data
if uses_train_param:
trainset = data_source(root=_get_download_dir(downloads_path), train=True, download=True)
else:
trainset = data_source(root=_get_download_dir(downloads_path), split="train", download=True)
data, labels = _get_data_and_labels(trainset, image_size)
if subset == "all" or subset == "test":
# Load test data
if uses_train_param:
testset = data_source(root=_get_download_dir(downloads_path), train=False, download=True)
else:
testset = data_source(root=_get_download_dir(downloads_path), split="test", download=True)
data_test, labels_test = _get_data_and_labels(testset, image_size)
if subset == "all":
# Add to train data
data = torch.cat([data, data_test], dim=0)
labels = torch.cat([labels, labels_test], dim=0)
else:
data = data_test
labels = labels_test
# Convert data to float and labels to int
data = data.float()
labels = labels.int()
ssl._create_default_https_context = default_ssl
# Check data dimensions
if data.dim() < 3 or data.dim() > 5:
raise Exception(
"Number of dimensions for torchvision data sets should be 3, 4 or 5. Here dim={0}".format(data.dim()))
# Normalize and flatten
data = _torch_normalize_and_flatten(data, flatten, normalize_channels, is_color_channel_last)
# Move data to CPU
data_cpu = data.detach().cpu().numpy()
labels_cpu = labels.detach().cpu().numpy()
return data_cpu, labels_cpu
def _torch_normalize_and_flatten(data: torch.Tensor, flatten: bool, normalize_channels: bool,
is_color_channel_last: bool):
"""
Helper function to load a data set from the torchvision package.
All data sets will be returned as a two-dimensional tensor, created out of the HWC (height, width, color channels) image representation.
Parameters
----------
data : torch.Tensor
The torch data tensor
flatten : bool
should the image data be flatten, i.e. should the format be changed to a (N x d) array.
If false, color images will be returned in the CHW format
normalize_channels : bool
normalize each color-channel of the images
is_color_channel_last : bool
if true, the color channels should be in the last dimension, known as HWC representation. Alternatively the color channel can be at the first position, known as CHW representation.
Only relevant for color images -> Should be None for grayscale images
Returns
-------
data : torch.Tensor
The (non-)normalized and (non-)flatten data tensor
"""
# Channels can be normalized
if normalize_channels:
data = _torch_normalize_channels(data, is_color_channel_last)
# Flatten shape
if flatten:
data = _torch_flatten_shape(data, is_color_channel_last, normalize_channels)
elif not normalize_channels:
# Change image to CHW format
if data.dim() == 4 and is_color_channel_last: # equals 2d color images
# Change to CHW representation
data = data.permute(0, 3, 1, 2)
assert data.shape[1] == 3, "Colored image must consist of three channels not " + data.shape[1]
elif data.dim() == 5 and is_color_channel_last: # equals 3d color-images
# Change to CHWD representation
data = data.permute(0, 4, 1, 2, 3)
assert data.shape[1] == 3, "Colored image must consist of three channels not {0}".format(data.shape[1])
return data
def _torch_normalize_channels(data: torch.Tensor, is_color_channel_last: bool) -> torch.Tensor:
"""
Normalize the color channels of a torch dataset
Parameters
----------
data : torch.Tensor
The torch data tensor
is_color_channel_last : bool
if true, the color channels should be in the last dimension, known as HWC representation. Alternatively the color channel can be at the first position, known as CHW representation.
Only relevant for color images -> Should be None for grayscale images
Returns
-------
data : torch.Tensor
The normalized data tensor in CHW format
"""
if data.dim() == 3 or (data.dim() == 4 and is_color_channel_last is None):
# grayscale images (2d or 3d)
data_mean = [data.mean()]
data_std = [data.std()]
elif data.dim() == 4: # equals 2d color images
if is_color_channel_last:
# Change to CHW representation
data = data.permute(0, 3, 1, 2)
assert data.shape[1] == 3, "Colored image must consist of three channels not " + data.shape[1]
# color images
data_mean = data.mean([0, 2, 3])
data_std = data.std([0, 2, 3])
elif data.dim() == 5: # equals 3d color-images
if is_color_channel_last:
# Change to CHWD representation
data = data.permute(0, 4, 1, 2, 3)
assert data.shape[1] == 3, "Colored image must consist of three channels not {0}".format(data.shape[1])
# color images
data_mean = data.mean([0, 2, 3, 4])
data_std = data.std([0, 2, 3, 4])
normalize = torchvision.transforms.Normalize(data_mean, data_std)
data = normalize(data)
return data
def _torch_flatten_shape(data: torch.Tensor, is_color_channel_last: bool, normalize_channels: bool) -> torch.Tensor:
"""
Convert torch data tensor from image to numerical vector.
Parameters
----------
data : torch.Tensor
is_color_channel_last : bool
if true, the color channels should be in the last dimension, known as HWC representation. Alternatively the color channel can be at the first position, known as CHW representation.
Only relevant for color images -> Should be None for grayscale images
normalize_channels : bool
normalize each color-channel of the images
Returns
-------
data : torch.Tensor
The flatten data vector
"""
# Flatten shape
if data.dim() == 3:
data = data.reshape(-1, data.shape[1] * data.shape[2])
elif data.dim() == 4:
# In case of 3d grayscale image is_color_channel_last is None
if is_color_channel_last is not None and (not is_color_channel_last or normalize_channels):
# Change representation to HWC
data = data.permute(0, 2, 3, 1)
assert is_color_channel_last is None or data.shape[
3] == 3, "Colored image must consist of three channels not {0}".format(data.shape[3])
data = data.reshape(-1, data.shape[1] * data.shape[2] * data.shape[3])
elif data.dim() == 5:
if not is_color_channel_last or normalize_channels:
# Change representation to HWDC
data = data.permute(0, 2, 3, 4, 1)
assert data.shape[4] == 3, "Colored image must consist of three channels not {0}".format(data.shape[4])
data = data.reshape(-1, data.shape[1] * data.shape[2] * data.shape[3] * data.shape[4])
return data
[docs]def load_mnist(subset: str = "all", flatten: bool = True, normalize_channels: bool = False,
downloads_path: str = None) -> (np.ndarray, np.ndarray):
"""
Load the MNIST data set. It consists of 70000 28x28 grayscale images showing handwritten digits (0 to 9).
The data set is composed of 60000 training and 10000 test images.
N=70000, d=784, k=10.
Parameters
----------
subset : str
can be 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
flatten : bool
should the image data be flatten, i.e. should the format be changed to a (N x d) array (default: True)
normalize_channels : bool
normalize each color-channel of the images (default: False)
downloads_path : bool
path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)
Returns
-------
data, labels : (np.ndarray, np.ndarray)
the data numpy array (70000 x 784), the labels numpy array (70000)
References
-------
https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST
"""
data, labels = _load_torch_image_data(torchvision.datasets.MNIST, subset, flatten, normalize_channels, True,
downloads_path, None)
return data, labels
[docs]def load_kmnist(subset: str = "all", flatten: bool = True, normalize_channels: bool = False,
downloads_path: str = None) -> (np.ndarray, np.ndarray):
"""
Load the Kuzushiji-MNIST data set. It consists of 70000 28x28 grayscale images showing Kanji characters.
It is composed of 10 different characters, each representing one column of hiragana.
The data set is composed of 60000 training and 10000 test images.
N=70000, d=784, k=10.
Parameters
----------
subset : str
can be 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
flatten : bool
should the image data be flatten, i.e. should the format be changed to a (N x d) array (default: True)
normalize_channels : bool
normalize each color-channel of the images (default: False)
downloads_path : str
path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)
Returns
-------
data, labels : (np.ndarray, np.ndarray)
the data numpy array (70000 x 784), the labels numpy array (70000)
References
-------
https://pytorch.org/vision/stable/generated/torchvision.datasets.KMNIST.html#torchvision.datasets.KMNIST
"""
data, labels = _load_torch_image_data(torchvision.datasets.KMNIST, subset, flatten, normalize_channels, True,
downloads_path, None)
return data, labels
[docs]def load_fmnist(subset: str = "all", flatten: bool = True, normalize_channels: bool = False,
downloads_path: str = None) -> (np.ndarray, np.ndarray):
"""
Load the Fashion-MNIST data set. It consists of 70000 28x28 grayscale images showing articles from the Zalando online store.
Each sample belongs to one of 10 product groups.
The data set is composed of 60000 training and 10000 test images.
N=70000, d=784, k=10.
Parameters
----------
subset : str
can be 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
flatten : bool
should the image data be flatten, i.e. should the format be changed to a (N x d) array (default: True)
normalize_channels : bool
normalize each color-channel of the images (default: False)
downloads_path : str
path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)
Returns
-------
data, labels : (np.ndarray, np.ndarray)
the data numpy array (70000 x 784), the labels numpy array (70000)
References
-------
https://pytorch.org/vision/stable/generated/torchvision.datasets.FashionMNIST.html#torchvision.datasets.FashionMNIST
"""
data, labels = _load_torch_image_data(torchvision.datasets.FashionMNIST, subset, flatten, normalize_channels, True,
downloads_path, None)
return data, labels
[docs]def load_usps(subset: str = "all", flatten: bool = True, normalize_channels: bool = False,
downloads_path: str = None) -> (np.ndarray, np.ndarray):
"""
Load the USPS data set. It consists of 9298 16x16 grayscale images showing handwritten digits (0 to 9).
The data set is composed of 7291 training and 2007 test images.
N=9298, d=256, k=10.
Parameters
----------
subset : str
can be 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
flatten : bool
should the image data be flatten, i.e. should the format be changed to a (N x d) array (default: True)
normalize_channels : bool
normalize each color-channel of the images (default: False)
downloads_path : str
path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)
Returns
-------
data, labels : (np.ndarray, np.ndarray)
the data numpy array (9298 x 256), the labels numpy array (9298)
References
-------
https://pytorch.org/vision/stable/generated/torchvision.datasets.USPS.html#torchvision.datasets.USPS
"""
data, labels = _load_torch_image_data(torchvision.datasets.USPS, subset, flatten, normalize_channels, True,
downloads_path, None)
return data, labels
[docs]def load_cifar10(subset: str = "all", flatten: bool = True, normalize_channels: bool = False,
downloads_path: str = None) -> (np.ndarray, np.ndarray):
"""
Load the CIFAR10 data set. It consists of 60000 32x32 color images showing different objects.
The classes are airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck.
The data set is composed of 50000 training and 10000 test images.
N=60000, d=3072, k=10.
Parameters
----------
subset : str
can be 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
flatten : bool
should the image data be flatten, i.e. should the format be changed to a (N x d) array.
If false, the image will be returned in the CHW format (default: True)
normalize_channels : bool
normalize each color-channel of the images (default: False)
downloads_path : str
path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)
Returns
-------
data, labels : (np.ndarray, np.ndarray)
the data numpy array (60000 x 3072), the labels numpy array (60000)
References
-------
https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10
"""
data, labels = _load_torch_image_data(torchvision.datasets.CIFAR10, subset, flatten, normalize_channels, True,
downloads_path, True)
return data, labels
[docs]def load_svhn(subset: str = "all", flatten: bool = True, normalize_channels: bool = False,
downloads_path: str = None) -> (np.ndarray, np.ndarray):
"""
Load the SVHN data set. It consists of 99289 32x32 color images showing house numbers (0 to 9).
The data set is composed of 73257 training and 26032 test images.
N=99289, d=3072, k=10.
Parameters
----------
subset : str
can be 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
flatten : bool
should the image data be flatten, i.e. should the format be changed to a (N x d) array.
If false, the image will be returned in the CHW format (default: True)
normalize_channels : bool
normalize each color-channel of the images (default: False)
downloads_path : str
path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)
Returns
-------
data, labels : (np.ndarray, np.ndarray)
the data numpy array (99289 x 3072), the labels numpy array (99289)
References
-------
https://pytorch.org/vision/stable/generated/torchvision.datasets.SVHN.html#torchvision.datasets.SVHN
"""
data, labels = _load_torch_image_data(torchvision.datasets.SVHN, subset, flatten, normalize_channels, False,
downloads_path, False)
return data, labels
[docs]def load_stl10(subset: str = "all", flatten: bool = True, normalize_channels: bool = False,
downloads_path: str = None) -> (np.ndarray, np.ndarray):
"""
Load the STL10 data set. It consists of 13000 96x96 color images showing different objects.
The classes are airplane, bird, car, cat, deer, dog, horse, monkey, ship and truck.
The data set is composed of 5000 training and 8000 test images.
N=13000, d=27648, k=10.
Parameters
----------
subset : str
can be 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
flatten : bool
should the image data be flatten, i.e. should the format be changed to a (N x d) array.
If false, the image will be returned in the CHW format (default: True)
normalize_channels : bool
normalize each color-channel of the images (default: False)
downloads_path : str
path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)
Returns
-------
data, labels : (np.ndarray, np.ndarray)
the data numpy array (13000 x 27648), the labels numpy array (13000)
References
-------
https://pytorch.org/vision/stable/generated/torchvision.datasets.STL10.html#torchvision.datasets.STL10
"""
data, labels = _load_torch_image_data(torchvision.datasets.STL10, subset, flatten, normalize_channels, False,
downloads_path, False)
return data, labels
[docs]def load_gtsrb(subset: str = "all", image_size: tuple = (32, 32), flatten: bool = True,
normalize_channels: bool = False, downloads_path: str = None) -> (np.ndarray, np.ndarray):
"""
Load the GTSRB (German Traffic Sign Recognition Benchmark) data set. It consists of 39270 color images showing 43 different traffic signs.
Example classes are: stop sign, speed limit 50 sign, speed limit 70 sign, construction site sign and many others.
The data set is composed of 26640 training and 12630 test images.
N=39270, d=image_size[0]*image_size[1]*3, k=43.
Parameters
----------
subset : str
can be 'all', 'test' or 'train'. 'all' combines test and train data (default: 'all')
image_size : tuple
the images of various sizes must be converted into a coherent size.
The tuple equals (width, height) of the images (default: (32, 32))
flatten : bool
should the image data be flatten, i.e. should the format be changed to a (N x d) array.
If false, the image will be returned in the CHW format (default: True)
normalize_channels : bool
normalize each color-channel of the images (default: False)
downloads_path : str
path to the directory where the data is stored (default: None -> [USER]/Downloads/clustpy_datafiles)
Returns
-------
data, labels : (np.ndarray, np.ndarray)
the data numpy array (39270 x image_size[0]*image_size[1]*3), the labels numpy array (20580)
References
-------
https://pytorch.org/vision/stable/generated/torchvision.datasets.GTSRB.html#torchvision.datasets.GTSRB
and
https://benchmark.ini.rub.de/
"""
data, labels = _load_torch_image_data(torchvision.datasets.GTSRB, subset, flatten, normalize_channels, False,
downloads_path, True, image_size)
return data, labels