import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Colormap
from matplotlib.colors import Normalize
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from scipy import stats
import matplotlib.patches as mpatches
from sklearn.base import TransformerMixin
"""
Constants
"""
# Circle, Square, Diamond, Plus, X, Triangle down, Star, Pentagon, Triangle Up, Triangle left, Triangle right, Hexagon
_MARKERS = ("o", "s", "D", "P", "X", "v", "*", "p", "^", ">", "<", "h")
_MIN_OBJECTS_FOR_DENS_PLOT = 3
def plot_1d_data(X: np.ndarray, labels: np.ndarray = None, centers: np.ndarray = None, true_labels: np.ndarray = None,
                 show_legend: bool = True, show_plot: bool = True) -> None:
    """
    Plot a one-dimensional data set.
    Every object is drawn as a vertical tick on a horizontal line at height 1.

    Parameters
    ----------
    X : np.ndarray
        the given data set, either of shape (n,) or (n, 1)
    labels : np.ndarray
        The cluster labels. Specifies the color of the plotted objects. Can be None (default: None)
    centers : np.ndarray
        The cluster centers, either of shape (k,) or (k, 1). Will be plotted as red crosses labeled by the corresponding cluster id. Can be None (default: None)
    true_labels : np.ndarray
        The ground truth labels. If given, the objects are drawn a second time on a slightly higher line, colored by the ground truth labels. Can be None (default: None)
    show_legend : bool
        Defines whether a legend should be shown (default: True)
    show_plot : bool
        Defines whether the plot should directly be plotted (default: True)
    """
    # Only (n,) and (n, 1) are accepted; the previous check also let e.g. (n, 1, d) arrays through
    assert X.ndim == 1 or (X.ndim == 2 and X.shape[1] == 1), "Data must be 1-dimensional"
    assert centers is None or centers.ndim == 1 or (centers.ndim == 2 and centers.shape[1] == 1), \
        "Centers must be 1-dimensional"
    # Optional: Get first column of data
    if X.ndim == 2:
        X = X[:, 0]
    min_value = np.min(X)
    max_value = np.max(X)
    plt.hlines(1, min_value, max_value)  # Draw a horizontal line
    # Plot a tick at each location specified in X
    plt.scatter(X, np.ones(X.shape[0]), marker='|', s=500, c=labels)
    if centers is not None:
        # Optional: Get first column of centers
        if centers.ndim == 2:
            centers = centers[:, 0]
        plt.scatter(centers, np.ones(centers.shape[0]), s=300, color="red", marker="x")
        # Rank of every center along the axis; the id text alternates above/below the line by rank,
        # so that ids of neighboring centers do not overlap
        center_ranks = np.argsort(np.argsort(centers))
        for center_id, (center, rank) in enumerate(zip(centers, center_ranks)):
            text_y = 1.0005 if rank % 2 == 0 else 0.9994
            plt.text(center, text_y, str(center_id), weight="bold")
    if true_labels is not None:
        # Second line slightly above the first one, colored by the ground truth
        plt.hlines(1.001, min_value, max_value)
        plt.scatter(X, np.full(X.shape[0], 1.001), marker='|', s=500, c=true_labels)
    if show_legend and labels is not None:
        unique_labels, cmap, norm = _get_cmap_and_norm(labels)
        _add_legend(plt, unique_labels, cmap, norm)
    if show_plot:
        plt.show()
def plot_2d_data(X: np.ndarray, labels: np.ndarray = None, centers: np.ndarray = None, true_labels: np.ndarray = None,
                 show_legend: bool = True, scattersize: int = 10, equal_axis: bool = False, container: plt.Axes = plt,
                 show_plot: bool = True) -> None:
    """
    Plot a two-dimensional data set.

    Parameters
    ----------
    X : np.ndarray
        the given data set, of shape (n, 2)
    labels : np.ndarray
        The cluster labels. Specifies the color of the plotted objects. Can be None (default: None)
    centers : np.ndarray
        The cluster centers. Will be plotted as red squares labeled by the corresponding cluster id. Can be None (default: None)
    true_labels : np.ndarray
        The ground truth labels. Specifies the symbol of the plotted objects. Can be None (default: None)
    show_legend : bool
        Defines whether a legend should be shown (default: True)
    scattersize : float
        The size of the scatters (default: 10)
    equal_axis : bool
        Defines whether the axes are to be scaled to the same value range (default: False)
    container : plt.Axes
        The container to which the scatter plot is added.
        If another container is defined, show_plot should usually be False (default: matplotlib.pyplot)
    show_plot : bool
        Defines whether the plot should directly be plotted (default: True)
    """
    # Both conditions must hold; with "or" every 2-d array passed and 1-d input raised an IndexError
    assert X.ndim == 2 and X.shape[1] == 2, "Data must be 2-dimensional"
    if true_labels is None:
        container.scatter(X[:, 0], X[:, 1], c=labels, s=scattersize)
    else:
        # Fix the color range over all marker groups so that a cluster label gets the same color in every group
        vmin = None if labels is None else np.min(labels)
        vmax = None if labels is None else np.max(labels)
        # Change marker for true labels
        for lab_index, true_lab in enumerate(np.unique(true_labels)):
            in_group = true_labels == true_lab
            marker = _MARKERS[lab_index % len(_MARKERS)]
            container.scatter(X[in_group, 0], X[in_group, 1], s=scattersize,
                              c=None if labels is None else labels[in_group], marker=marker,
                              vmin=vmin, vmax=vmax)
    if centers is not None:
        container.scatter(centers[:, 0], centers[:, 1], s=scattersize * 1.5, color="red", marker="s")
        for center_id, center in enumerate(centers):
            container.text(center[0], center[1], str(center_id), weight="bold")
    if equal_axis:
        container.axis("equal")
    if show_legend and labels is not None:
        unique_labels, cmap, norm = _get_cmap_and_norm(labels)
        _add_legend(container, unique_labels, cmap, norm)
    if show_plot:
        # plt.show() instead of container.show(): a matplotlib Axes has no show() method
        plt.show()
def plot_3d_data(X: np.ndarray, labels: np.ndarray = None, centers: np.ndarray = None, true_labels: np.ndarray = None,
                 show_legend: bool = True, scattersize: int = 10, show_plot: bool = True) -> None:
    """
    Plot a three-dimensional data set in a new figure.

    Parameters
    ----------
    X : np.ndarray
        the given data set, of shape (n, 3)
    labels : np.ndarray
        The cluster labels. Specifies the color of the plotted objects. Can be None (default: None)
    centers : np.ndarray
        The cluster centers. Will be plotted as red squares labeled by the corresponding cluster id. Can be None (default: None)
    true_labels : np.ndarray
        The ground truth labels. Specifies the symbol of the plotted objects. Can be None (default: None)
    show_legend : bool
        Defines whether a legend should be shown (default: True)
    scattersize : float
        The size of the scatters (default: 10)
    show_plot : bool
        Defines whether the plot should directly be plotted.
        If True, a new empty figure is created afterwards for future plots (default: True)
    """
    # Both conditions must hold; with "or" every 2-d array passed
    assert X.ndim == 2 and X.shape[1] == 3, "Data must be 3-dimensional"
    fig = plt.figure()
    # Axes3D(fig) is no longer attached to the figure automatically in recent Matplotlib versions
    # (the figure stayed empty), so the 3d axes is created through the figure instead
    ax = fig.add_subplot(111, projection='3d')
    if true_labels is None:
        ax.scatter(X[:, 0], X[:, 1], zs=X[:, 2], zdir='z', s=scattersize, c=labels, alpha=0.8)
    else:
        # Fix the color range over all marker groups so that a cluster label gets the same color in every group
        vmin = None if labels is None else np.min(labels)
        vmax = None if labels is None else np.max(labels)
        # Change marker for true labels
        for lab_index, true_lab in enumerate(np.unique(true_labels)):
            in_group = true_labels == true_lab
            marker = _MARKERS[lab_index % len(_MARKERS)]
            ax.scatter(X[in_group, 0], X[in_group, 1], zs=X[in_group, 2], zdir='z', s=scattersize,
                       c=None if labels is None else labels[in_group],
                       marker=marker, vmin=vmin, vmax=vmax, alpha=0.8)
    if centers is not None:
        ax.scatter(centers[:, 0], centers[:, 1], zs=centers[:, 2], zdir='z', s=scattersize * 1.5, color="red",
                   marker="s")
        for center_id, center in enumerate(centers):
            ax.text(center[0], center[1], center[2], str(center_id), weight="bold")
    if show_legend and labels is not None:
        unique_labels, cmap, norm = _get_cmap_and_norm(labels)
        _add_legend(fig, unique_labels, cmap, norm)
    if show_plot:
        plt.show()
        plt.figure()  # Create new figure for future plots
def plot_image(img_data: np.ndarray, black_and_white: bool = False, image_shape: tuple = None, max_value: float = None,
               min_value: float = None, show_plot: bool = True) -> None:
    """
    Plot an image.
    Expects a color image to occur in the HWC representation (height, width, color channels).

    Parameters
    ----------
    img_data : np.ndarray
        The image data
    black_and_white : bool
        Specifies whether the image should be plotted in grayscale colors. Only relevant for images without color channels (default: False)
    image_shape : tuple
        (height, width) for grayscale images or (height, width, number of channels) for color images.
        Only used if img_data is one-dimensional (default: None)
    max_value : float
        maximum pixel value, used for min-max normalization. Is often 255, if None the maximum value in the data set will be used (default: None)
    min_value : float
        minimum pixel value, used for min-max normalization. Is often 0, if None the minimum value in the data set will be used (default: None)
    show_plot : bool
        Defines whether the plot should directly be plotted (default: True)

    Examples
    ----------
    from clustpy.data import load_nrletters, load_optdigits
    X, _ = load_nrletters()
    plot_image(X[0], False, (9, 7, 3), 255, 0, show_plot=True)
    X, _ = load_optdigits()
    plot_image(X[0], True, (8, 8), 255, 0, show_plot=True)
    """
    assert img_data.ndim <= 3, "Image data can not have more than 3 dimensions."
    # Data range must match float between [0..1] or int between [0..255] -> use min-max transform
    if max_value is None:
        max_value = np.max(img_data)
    if min_value is None:
        min_value = np.min(img_data)
    # Work in floating point so that unsigned integer images cannot wrap around when min_value is subtracted
    img_data = np.asarray(img_data, dtype=float) - min_value
    value_range = max_value - min_value
    if value_range != 0:
        img_data = img_data / value_range
    # else: constant image -> keep the shifted (all-zero) values instead of dividing by zero, which produced NaNs
    # Reshape array data
    if img_data.ndim == 1:
        img_data = img_data.reshape(image_shape)
    # Plot original image or a black-and-white version
    if black_and_white:
        plt.imshow(img_data, cmap="Greys")
    else:
        plt.imshow(img_data)
    plt.axis('off')
    if show_plot:
        plt.show()
def plot_histogram(X: np.ndarray, labels: np.ndarray = None, density: bool = True, n_bins: int = 100,
                   show_legend: bool = True, container: plt.Axes = plt, show_plot: bool = True) -> None:
    """
    Plot a histogram.

    Parameters
    ----------
    X : np.ndarray
        the given data set
    labels : np.ndarray
        The cluster labels. Specifies the color of the plotted objects. Can be None (default: None)
    density : bool
        Defines whether a kernel density should be added to the histogram (default: True)
    n_bins : int
        Number of bins (default: 100)
    show_legend : bool
        Defines whether the legend of the histogram should be shown (default: True)
    container : plt.Axes
        The container to which the histogram is added.
        If another container is defined, show_plot should usually be False (default: matplotlib.pyplot)
    show_plot : bool
        Defines whether the plot should directly be plotted (default: True)
    """
    assert X.ndim == 1, "Data must be 1-dimensional"
    # All histograms use the full value range of X so that the bins of the different clusters line up
    hist_range = (np.min(X), np.max(X))
    # Plot histogram
    if labels is not None:
        unique_labels, cmap, norm = _get_cmap_and_norm(labels)
        for lab in unique_labels:
            # Get common label colors for histogram and density
            hist_color = cmap(norm(lab))
            container.hist(X[labels == lab], alpha=0.5, bins=n_bins, color=hist_color, range=hist_range)
    else:
        container.hist(X, alpha=0.5, bins=n_bins, range=hist_range)
    # Plot densities
    if density:
        # Histogram and density should share same x-axis
        twin_axis = container.twinx()
        twin_axis.yaxis.set_visible(False)
        if labels is not None:
            for lab in unique_labels:
                _plot_density(twin_axis, X[labels == lab], cmap(norm(lab)))
        else:
            _plot_density(twin_axis, X)
    if show_legend and labels is not None:
        _add_legend(container, unique_labels, cmap, norm)
    if show_plot:
        plt.show()


def _plot_density(axis: plt.Axes, values: np.ndarray, color=None) -> None:
    """
    Helper function to draw a Gaussian kernel density estimate of values onto axis.
    Nothing is drawn if values contains fewer than _MIN_OBJECTS_FOR_DENS_PLOT objects or if all values are identical
    (gaussian_kde raises an error for such zero-variance data).

    Parameters
    ----------
    axis : plt.Axes
        The axis on which the density curve is drawn
    values : np.ndarray
        The one-dimensional sample whose density should be drawn
    color
        The line color. If None, matplotlib's default color cycle is used (default: None)
    """
    if values.shape[0] < _MIN_OBJECTS_FOR_DENS_PLOT or np.ptp(values) == 0:
        return
    kde = stats.gaussian_kde(values)
    steps = np.linspace(np.min(values), np.max(values), 1000)
    if color is None:
        axis.plot(steps, kde(steps))
    else:
        axis.plot(steps, kde(steps), color=color)
def plot_scatter_matrix(X: np.ndarray, labels: np.ndarray = None, centers: np.ndarray = None,
                        true_labels: np.ndarray = None, density: bool = True, n_bins: int = 100,
                        show_legend: bool = True, scattersize: int = 10, equal_axis: bool = False,
                        max_dimensions: int = 10, show_plot: bool = True) -> plt.Axes:
    """
    Create a scatter matrix plot.
    Visualizes a 2d scatter plot for each combination of features.
    The center axis shows a histogram of each single feature.

    Parameters
    ----------
    X : np.ndarray
        the given data set
    labels : np.ndarray
        The cluster labels. Specifies the color of the plotted objects. Can be None (default: None)
    centers : np.ndarray
        The cluster centers. Will be plotted as red squares labeled by the corresponding cluster id. Can be None (default: None)
    true_labels : np.ndarray
        The ground truth labels. Specifies the symbol of the plotted objects. Can be None (default: None)
    density : bool
        Defines whether a kernel density should be added to the histogram (default: True)
    n_bins : int
        Number of bins used for the histogram (default: 100)
    show_legend : bool
        Defines whether a legend should be shown (default: True)
    scattersize : float
        The size of the scatters (default: 10)
    equal_axis : bool
        Defines whether the axes are to be scaled to the same value range (default: False)
    max_dimensions : int
        Maximum Number of dimensions that should be plotted.
        If the data set has more dimensions, a warning is printed and no plot is created.
        This value is intended to prevent the creation of overly complex plots that are very confusing and take a long time to create (default: 10)
    show_plot : bool
        Defines whether the plot should directly be plotted (default: True)

    Returns
    -------
    axes : plt.Axes
        The used matplotlib axes (None if the plot was aborted because of max_dimensions)
    """
    if X.shape[1] > max_dimensions:
        print("[WARNING] Dimensionality of the dataset is larger than {0}. "
              "Creation of scatter matrix plot will be aborted.".format(max_dimensions))
        # Actually abort, as announced by the warning
        return None
    # For single dimension only plot histogram
    if X.shape[1] == 1:
        plot_histogram(X[:, 0], labels, density, n_bins, show_legend, show_plot=show_plot)
        return plt.gca()
    if labels is not None:
        unique_labels, cmap, norm = _get_cmap_and_norm(labels)
    # Create subplots
    if equal_axis:
        fig, axes = plt.subplots(nrows=X.shape[1], ncols=X.shape[1], sharey="all", sharex="all")
    else:
        fig, axes = plt.subplots(nrows=X.shape[1], ncols=X.shape[1], sharey="row", sharex="col")
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    for i in range(X.shape[1]):
        for j in range(X.shape[1]):
            ax = axes[i, j]
            if i == j:
                # Histogram plot
                if i != 0:
                    ax.yaxis.set_visible(False)
                if i != X.shape[1] - 1:
                    ax.xaxis.set_visible(False)
                # Second plot for actual histogram (use container); keeps the histogram's y-scale
                # independent of the y-axis shared with the scatter plots in this row
                twin_axis = ax.twinx()
                twin_axis.yaxis.set_visible(False)
                plot_histogram(X[:, i], labels, density, n_bins, show_legend=False, container=twin_axis,
                               show_plot=False)
            else:
                # Scatter plot (use container); feature j on the x-axis, feature i on the y-axis
                local_centers = None if centers is None else centers[:, [j, i]]
                plot_2d_data(X[:, [j, i]], labels, local_centers, true_labels, show_legend=False,
                             scattersize=scattersize, equal_axis=False, container=ax, show_plot=False)
    if show_legend and labels is not None:
        _add_legend(fig, unique_labels, cmap, norm)
    if show_plot:
        plt.show()
    return axes
def _add_legend(container: plt.Axes, unique_labels: np.ndarray, cmap: Colormap, norm: Normalize) -> None:
    """
    Helper function to attach a legend with one colored patch per cluster label.

    Parameters
    ----------
    container : plt.Axes
        The object that receives the legend; anything with a ``legend`` method
        (e.g. matplotlib.pyplot, an Axes or a Figure)
    unique_labels : np.ndarray
        The unique labels that should be displayed in the legend
    cmap : Colormap
        the colormap
    norm : Normalize
        The Normalize object that maps a label to a position in the colormap
    """
    handles = []
    for label in unique_labels:
        label_color = cmap(norm(label))
        handles.append(mpatches.Patch(color=label_color, label=label))
    container.legend(handles=handles, loc="center right")
def _get_cmap_and_norm(labels: np.ndarray) -> (np.ndarray, Colormap, Normalize):
    """
    Helper function to get colormap and Normalization object.
    The normalization maps the smallest label to the start and the largest label to the end of the colormap.

    Parameters
    ----------
    labels : np.ndarray
        The cluster labels

    Returns
    -------
    tuple : (np.ndarray, Colormap, Normalize)
        The unique labels ids (sorted),
        The colormap,
        The Normalize object to pick the correct color
    """
    unique_labels = np.unique(labels)
    # Manage colormap:
    # - pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    # - the full (not resampled) viridis map is used so that legend colors match the scatter plots, which color
    #   the objects with matplotlib's default colormap and a min-max normalization of the labels
    cmap = plt.get_cmap('viridis')
    norm = Normalize(vmin=unique_labels[0], vmax=unique_labels[-1])
    return unique_labels, cmap, norm