"""
@authors:
Collin Leiber,
Julian Schilcher
"""
import numpy as np
import torch
from clustpy.deep._utils import squared_euclidean_distance, encode_batchwise, predict_batchwise, \
embedded_kmeans_prediction
from clustpy.deep._train_utils import get_default_deep_clustering_initialization
from sklearn.cluster import KMeans
from clustpy.deep._abstract_deep_clustering_algo import _AbstractDeepClusteringAlgo
from clustpy.hierarchical._cluster_tree import BinaryClusterTree, _ClusterTreeNode
import tqdm
import copy
class _DeepECT_ClusterTreeNode(_ClusterTreeNode):
    """
    Cluster tree node used by DeepECT.
    In addition to what the base node provides, it stores a trainable center, a weight and a torch copy of the
    node's labels (torch_labels).
    """

    def set_center_weight_and_torch_labels(self, center: np.ndarray, weight: float, optimizer: torch.optim.Optimizer,
                                           device: torch.device) -> None:
        """
        Attach a trainable center and a weight to this node and mirror its labels as a torch tensor on the device.
        The center is registered at the optimizer as an additional parameter group.

        Parameters
        ----------
        center : np.ndarray
            The cluster center
        weight : float
            The cluster weight
        optimizer : torch.optim.Optimizer
            Optimizer for training
        device : torch.device
            device to be trained on
        """
        center_on_device = torch.tensor(center).to(device)
        self.center = torch.nn.Parameter(center_on_device, requires_grad=True)
        self.weight = weight
        # The center is optimized together with the parameters already known to the optimizer
        optimizer.add_param_group({"params": self.center})
        self.torch_labels = torch.tensor(self.labels, dtype=torch.int32).to(device)

    def update_parents_torch_labels(self, device: torch.device) -> None:
        """
        Refresh the torch_labels attribute of the ancestors of this node.
        Has to be called when a new node has been added or a node has been deleted.

        Parameters
        ----------
        device : torch.device
            device to be trained on
        """
        ancestor = self.parent_node
        while ancestor is not None:
            refreshed_labels = torch.tensor(ancestor.labels, dtype=torch.int32).to(device)
            is_up_to_date = hasattr(ancestor, "torch_labels") and torch.equal(ancestor.torch_labels,
                                                                              refreshed_labels)
            if is_up_to_date:
                # This ancestor already holds the current labels, so the walk towards the root stops here
                break
            ancestor.torch_labels = refreshed_labels
            ancestor = ancestor.parent_node
class _DeepECT_Module(torch.nn.Module):
    """
    The _DeepECT_Module. Contains most of the algorithm specific procedures like the loss and tree-grow functions.

    Parameters
    ----------
    cluster_tree: BinaryClusterTree
        The cluster tree
    max_n_leaf_nodes : int
        Maximum number of leaf nodes in the cluster tree
    grow_interval : int
        Number of epochs after which the tree is grown
    pruning_threshold : float
        The threshold for pruning the tree
    augmentation_invariance : bool
        If True, augmented samples provided in custom_dataloaders[0] will be used to learn
        cluster assignments that are invariant to the augmentation transformations (default: False)
    """

    def __init__(self, cluster_tree: "BinaryClusterTree", max_n_leaf_nodes: int, grow_interval: int,
                 pruning_threshold: float, augmentation_invariance: bool = False):
        super().__init__()
        # The cluster tree is created by the caller and modified in place by this module
        self.cluster_tree = cluster_tree
        self.max_n_leaf_nodes = max_n_leaf_nodes
        self.grow_interval = grow_interval
        self.pruning_threshold = pruning_threshold
        self.augmentation_invariance = augmentation_invariance

    def predict_hard(self, embedded: torch.Tensor) -> torch.Tensor:
        """
        Hard prediction of the given embedded samples. Returns the corresponding hard labels.
        Uses the minimum squared Euclidean distance to the cluster centers of the leaf nodes to get the labels.

        Parameters
        ----------
        embedded : torch.Tensor
            the embedded samples

        Returns
        -------
        labels : torch.Tensor
            the final labels
        """
        leaf_nodes, _ = self.cluster_tree.get_leaf_and_split_nodes()
        # The samples are moved to the device on which the leaf centers live
        device = leaf_nodes[0].center.device
        _, _, labels = self._get_labels_from_leafs(embedded.to(device), leaf_nodes)
        labels = labels.detach().cpu()
        return labels

    def _get_labels_from_leafs(self, embedded: torch.Tensor, leaf_nodes: list) -> (
            torch.Tensor, torch.Tensor, torch.Tensor):
        """
        Get the cluster assignments of the current batch by considering the distance to the closest center of a leaf node.
        The assignment of a sample to a cluster center is represented by the index of the center and by the actual label of the assigned leaf node.
        These values usually differ and both values are returned.

        Parameters
        ----------
        embedded : torch.Tensor
            The embedded batch of data
        leaf_nodes : list
            list containing all leaf nodes within the cluster tree

        Returns
        -------
        tuple : (torch.Tensor, torch.Tensor, torch.Tensor)
            The centers of the leaf nodes,
            The index of the cluster center assigned to each sample,
            The labels of the samples
        """
        leaf_centers = torch.stack([leaf.center for leaf in leaf_nodes], dim=0)
        # The first entry of torch_labels is used as the label of a leaf node
        leaf_labels = torch.stack([leaf.torch_labels[0] for leaf in leaf_nodes])
        # Get distances between points and centers. Get nearest center
        squared_diffs = squared_euclidean_distance(embedded, leaf_centers)
        cluster_center_assignments = (squared_diffs.min(dim=1)[1]).int()
        labels = leaf_labels[cluster_center_assignments]
        return leaf_centers, cluster_center_assignments, labels

    def _grow_tree(self, testloader: torch.utils.data.DataLoader, neural_network: torch.nn.Module, leaf_nodes: list,
                   new_cluster_id: int, optimizer: torch.optim.Optimizer, device: torch.device,
                   random_state: np.random.RandomState) -> bool:
        """
        Grows the tree at the leaf node with the highest squared distances between its assigned samples and the center.
        The distance is not normalized, so larger clusters will be weighted higher.
        After the leaf node with highest squared distances has been identified, it will be split into two leaf nodes by performing bisecting KMeans.
        If no leaf node contains more than one sample with a positive sum of squared distances, the tree is left unchanged.

        Parameters
        ----------
        testloader : torch.utils.data.DataLoader
            dataloader to be used for updating the clustering parameters
        neural_network : torch.nn.Module
            the neural network
        leaf_nodes : list
            list containing all leaf nodes within the cluster tree
        new_cluster_id : int
            the new cluster ID that should be added to the tree
        optimizer : torch.optim.Optimizer
            Optimizer for training
        device : torch.device
            device to be trained on
        random_state : np.random.RandomState
            use a fixed random state to get a repeatable solution

        Returns
        -------
        has_grown : bool
            True if a leaf node has been split, False if no leaf node could be split
        """
        embedded = encode_batchwise(testloader, neural_network)
        embedded_torch = torch.from_numpy(embedded).to(device)
        leaf_centers, cluster_center_assignments, _ = self._get_labels_from_leafs(embedded_torch, leaf_nodes)
        # Search leaf node with max distances
        squared_distances = (embedded_torch - leaf_centers[cluster_center_assignments]).pow(2).sum(1)
        leaf_to_split = None
        max_sum_of_squared = 0
        for leaf_id in range(leaf_centers.shape[0]):
            squared_distances_clust = squared_distances[cluster_center_assignments == leaf_id]
            # Check that cluster has more than 1 sample
            if squared_distances_clust.shape[0] > 1:
                sum_of_squared_clust = squared_distances_clust.sum()
                if sum_of_squared_clust > max_sum_of_squared:
                    max_sum_of_squared = sum_of_squared_clust
                    leaf_to_split = leaf_id
        if leaf_to_split is None:
            # No leaf node can be split (e.g., every leaf node holds at most one sample)
            return False
        # Split node
        node_to_split = leaf_nodes[leaf_to_split]
        new_left_node, new_right_node = self.cluster_tree.split_cluster(node_to_split.labels[0], new_cluster_id)
        samples_in_leaf = cluster_center_assignments.detach().cpu().numpy() == leaf_to_split
        km = KMeans(n_clusters=2, n_init=20, random_state=random_state).fit(embedded[samples_in_leaf])
        new_left_node.set_center_weight_and_torch_labels(km.cluster_centers_[0], 1, optimizer, device)
        new_right_node.set_center_weight_and_torch_labels(km.cluster_centers_[1], 1, optimizer, device)
        # Has to be called only once as the parent is the same for the left and right node
        new_left_node.update_parents_torch_labels(device)
        # Change old center from torch.nn.Parameter to regular Tensor
        node_to_split.center = node_to_split.center.data
        return True

    def _update_split_node_centers(self, split_nodes: list, leaf_nodes: list, labels: torch.Tensor) -> list:
        """
        Update the centers and the weights of the split nodes analytically as described in the paper.
        The weights of the leaf nodes are updated as well.
        Returns a list containing all nodes (except the root) whose weight is below the pruning threshold (can be empty).

        Parameters
        ----------
        split_nodes : list
            list containing all split nodes within the cluster tree
        leaf_nodes : list
            list containing all leaf nodes within the cluster tree
        labels : torch.Tensor
            labels of the samples

        Returns
        -------
        nodes_to_prune : list
            list containing all nodes whose weight is now below the pruning threshold
        """
        nodes_to_prune = []
        for node in split_nodes + leaf_nodes:
            if not node.is_leaf_node():
                # Update center of split nodes: weighted average of the centers of both children
                left_child = node.left_node_
                right_child = node.right_node_
                node.center = (left_child.weight * left_child.center + right_child.weight * right_child.center) / (
                        left_child.weight + right_child.weight)
            # Update weight of all nodes except root node
            if node.parent_node is not None:
                n_samples_in_node = torch.isin(labels, node.torch_labels).sum()
                # Equally weighted moving average of the old weight and the number of samples in the current batch
                node.weight = 0.5 * node.weight + 0.5 * n_samples_in_node
                if node.weight < self.pruning_threshold:
                    nodes_to_prune.append(node)
        return nodes_to_prune

    def _prune_tree(self, nodes_to_prune: list, device: torch.device) -> None:
        """
        Delete all nodes within nodes_to_prune from the cluster tree.

        Parameters
        ----------
        nodes_to_prune : list
            Contains all nodes that should be deleted. Can also be empty.
        device : torch.device
            device to be trained on
        """
        for node in nodes_to_prune:
            sibling = node.delete_node()
            if sibling is not None:
                # The labels of the ancestors changed due to the deletion
                sibling.update_parents_torch_labels(device)

    def _node_center_loss(self, embedded: torch.Tensor, leaf_centers: torch.Tensor,
                          cluster_center_assignments: torch.Tensor, embedded_aug: torch.Tensor) -> torch.Tensor:
        """
        Calculate the node center loss L_nc.

        Parameters
        ----------
        embedded : torch.Tensor
            The embedded batch of data
        leaf_centers : torch.Tensor
            The centers of the leaf nodes
        cluster_center_assignments : torch.Tensor
            The index of the cluster center assigned to each sample
        embedded_aug : torch.Tensor
            the embedded augmented batch of data

        Returns
        -------
        nc_loss : torch.Tensor
            The node center loss
        """
        # torch.unique returns the assignments in ascending order, so the rows of the selected leaf centers
        # and of the batch centers below are aligned.
        unique_assignments = torch.unique(cluster_center_assignments)
        # Note that a batch need not contain samples from all leaf nodes
        leaf_centers_in_batch = leaf_centers[unique_assignments.long()]
        centers = torch.stack(
            [torch.mean(embedded[cluster_center_assignments == assign], dim=0) for assign in unique_assignments],
            dim=0)
        if self.augmentation_invariance:
            centers_aug = torch.stack(
                [torch.mean(embedded_aug[cluster_center_assignments == assign], dim=0) for assign in
                 unique_assignments], dim=0)
            centers = (centers + centers_aug) / 2
        # Calculate loss. The batch centers are detached, so this loss only moves the leaf centers
        sum_centers_dist = torch.linalg.vector_norm(leaf_centers_in_batch - centers.detach(), dim=1).sum()
        nc_loss = sum_centers_dist / leaf_centers.shape[0]
        return nc_loss

    def _data_compression_loss(self, embedded: torch.Tensor, split_nodes: list, labels: torch.Tensor,
                               device: torch.device, embedded_aug: torch.Tensor) -> torch.Tensor:
        """
        Calculate the data compression loss L_dc.
        Returns a loss of zero if there are no split nodes or if the batch is empty.

        Parameters
        ----------
        embedded : torch.Tensor
            The embedded batch of data
        split_nodes : list
            list containing all split nodes within the cluster tree
        labels : torch.Tensor
            labels of the samples
        device : torch.device
            device to be trained on
        embedded_aug : torch.Tensor
            the embedded augmented batch of data

        Returns
        -------
        dc_loss : torch.Tensor
            The data compression loss
        """
        dc_loss = torch.tensor(0., device=device)
        n_split_nodes = len(split_nodes)
        n_samples = embedded.shape[0]
        if n_split_nodes == 0 or n_samples == 0:
            # Without this guard the normalization below would divide by zero and produce NaN
            return dc_loss
        for node in split_nodes:
            left_child = node.left_node_
            right_child = node.right_node_
            samples_in_left = torch.isin(labels, left_child.torch_labels)
            samples_in_right = torch.isin(labels, right_child.torch_labels)
            any_in_left = torch.any(samples_in_left)
            any_in_right = torch.any(samples_in_right)
            # Check if samples are contained in subtree
            if not (any_in_left or any_in_right):
                continue
            # Direction between the two child centers; only the norm is detached
            center_diff = left_child.center - right_child.center
            proj = center_diff / torch.linalg.vector_norm(center_diff).detach()
            if any_in_left:
                # Loss on left side
                left_center = left_child.center.detach()
                dc_loss = dc_loss + torch.abs(torch.matmul(left_center - embedded[samples_in_left], proj)).sum()
                if self.augmentation_invariance:
                    dc_loss = dc_loss + torch.abs(
                        torch.matmul(left_center - embedded_aug[samples_in_left], proj)).sum()
            if any_in_right:
                # Loss on right side
                right_center = right_child.center.detach()
                dc_loss = dc_loss + torch.abs(torch.matmul(right_center - embedded[samples_in_right], proj)).sum()
                if self.augmentation_invariance:
                    dc_loss = dc_loss + torch.abs(
                        torch.matmul(right_center - embedded_aug[samples_in_right], proj)).sum()
        dc_loss = dc_loss / (2 * n_split_nodes * n_samples)
        if self.augmentation_invariance:
            # Original and augmented samples both contributed to the sum
            dc_loss = dc_loss / 2
        return dc_loss

    def _loss(self, batch: list, neural_network: torch.nn.Module, ssl_loss_fn: torch.nn.modules.loss._Loss,
              clustering_loss_weight: float, ssl_loss_weight: float, leaf_nodes: list, split_nodes: list,
              device: torch.device) -> (torch.Tensor, torch.Tensor):
        """
        Calculate the complete DeepECT + neural network loss.

        Parameters
        ----------
        batch : list
            the minibatch
        neural_network : torch.nn.Module
            the neural network
        ssl_loss_fn : torch.nn.modules.loss._Loss
            self-supervised learning (ssl) loss function for training the network, e.g. reconstruction loss for autoencoders
        clustering_loss_weight : float
            weight of the clustering loss
        ssl_loss_weight : float
            weight of the self-supervised learning (ssl) loss
        leaf_nodes : list
            list containing all leaf nodes within the cluster tree
        split_nodes : list
            list containing all split nodes within the cluster tree
        device : torch.device
            device to be trained on

        Returns
        -------
        loss : (torch.Tensor, torch.Tensor)
            the final DeepECT loss,
            the labels of the samples
        """
        # compute self-supervised loss
        if self.augmentation_invariance:
            ssl_loss, embedded, _, embedded_aug, _ = neural_network.loss_augmentation(batch, ssl_loss_fn, device)
        else:
            ssl_loss, embedded, _ = neural_network.loss(batch, ssl_loss_fn, device)
            embedded_aug = None
        # calculate cluster loss
        leaf_centers, cluster_center_assignments, labels = self._get_labels_from_leafs(embedded, leaf_nodes)
        nc_loss = self._node_center_loss(embedded, leaf_centers, cluster_center_assignments, embedded_aug)
        dc_loss = self._data_compression_loss(embedded, split_nodes, labels, device, embedded_aug)
        # Combine losses
        loss = clustering_loss_weight * (nc_loss + dc_loss) + ssl_loss_weight * ssl_loss
        return loss, labels

    def fit(self, neural_network: torch.nn.Module, trainloader: torch.utils.data.DataLoader,
            testloader: torch.utils.data.DataLoader, n_epochs: int, device: torch.device,
            optimizer: torch.optim.Optimizer, ssl_loss_fn: torch.nn.modules.loss._Loss, clustering_loss_weight: float,
            ssl_loss_weight: float, random_state: np.random.RandomState) -> "_DeepECT_Module":
        """
        Trains the _DeepECT_Module in place.

        Parameters
        ----------
        neural_network : torch.nn.Module
            the neural network
        trainloader : torch.utils.data.DataLoader
            dataloader to be used for training
        testloader : torch.utils.data.DataLoader
            dataloader to be used for updating the clustering parameters
        n_epochs : int
            number of epochs for the clustering procedure
        device : torch.device
            device to be trained on
        optimizer : torch.optim.Optimizer
            Optimizer for training
        ssl_loss_fn : torch.nn.modules.loss._Loss
            self-supervised learning (ssl) loss function for training the network, e.g. reconstruction loss for autoencoders
        clustering_loss_weight : float
            weight of the clustering loss
        ssl_loss_weight : float
            weight of the self-supervised learning (ssl) loss
        random_state : np.random.RandomState
            use a fixed random state to get a repeatable solution

        Returns
        -------
        self : _DeepECT_Module
            This instance of the _DeepECT_Module
        """
        cluster_id = 2  # Two clusters were created during the initialization of the algorithm
        leaf_nodes, split_nodes = self.cluster_tree.get_leaf_and_split_nodes()
        tbar = tqdm.trange(n_epochs, desc="DeepECT training")
        for epoch in tbar:
            total_loss = 0
            with torch.no_grad():
                # Grow tree
                is_grow_epoch = epoch % self.grow_interval == 0 or self.cluster_tree.n_leaf_nodes_ < 2
                if is_grow_epoch and len(leaf_nodes) < self.max_n_leaf_nodes:
                    has_grown = self._grow_tree(testloader, neural_network, leaf_nodes, cluster_id, optimizer,
                                                device, random_state)
                    if has_grown:
                        # Only consume the cluster ID if a new leaf node has actually been created
                        cluster_id += 1
                        leaf_nodes, split_nodes = self.cluster_tree.get_leaf_and_split_nodes()
            # Update Network
            for batch in trainloader:
                # Calculate loss
                loss, labels = self._loss(batch, neural_network, ssl_loss_fn, clustering_loss_weight,
                                          ssl_loss_weight, leaf_nodes, split_nodes, device)
                total_loss += loss.item()
                # Backward pass - update weights
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Adapt centers and weights of split nodes analytically
                with torch.no_grad():
                    nodes_to_prune = self._update_split_node_centers(split_nodes, leaf_nodes, labels)
                    # Prune Tree
                    if len(nodes_to_prune) > 0:
                        self._prune_tree(nodes_to_prune, device)
                        leaf_nodes, split_nodes = self.cluster_tree.get_leaf_and_split_nodes()
            postfix_str = {"Loss": total_loss}
            tbar.set_postfix(postfix_str)
        return self
def _deep_ect(X: np.ndarray, max_n_leaf_nodes: int, batch_size: int, pretrain_optimizer_params: dict,
              clustering_optimizer_params: dict, pretrain_epochs: int, clustering_epochs: int, grow_interval: int,
              pruning_threshold: float, optimizer_class: torch.optim.Optimizer,
              ssl_loss_fn: torch.nn.modules.loss._Loss, neural_network: torch.nn.Module | tuple,
              neural_network_weights: str, embedding_size: int, clustering_loss_weight: float, ssl_loss_weight: float,
              custom_dataloaders: tuple, augmentation_invariance: bool, device: torch.device,
              random_state: np.random.RandomState) -> "tuple[BinaryClusterTree, np.ndarray, torch.nn.Module]":
    """
    Start the actual DeepECT clustering procedure on the input data set.

    Parameters
    ----------
    X : np.ndarray
        The given data set. Can be a np.ndarray or a torch.Tensor
    max_n_leaf_nodes : int
        Maximum number of leaf nodes in the cluster tree
    batch_size : int
        Size of the data batches
    pretrain_optimizer_params : dict
        Parameters of the optimizer for the pretraining of the neural network, includes the learning rate
    clustering_optimizer_params : dict
        Parameters of the optimizer for the actual clustering procedure, includes the learning rate
    pretrain_epochs : int
        Number of epochs for the pretraining of the neural network
    clustering_epochs : int
        Number of epochs for the actual clustering procedure
    grow_interval : int
        Number of epochs after which the tree is grown
    pruning_threshold : float
        The threshold for pruning the tree
    optimizer_class : torch.optim.Optimizer
        The optimizer class
    ssl_loss_fn : torch.nn.modules.loss._Loss
        self-supervised learning (ssl) loss function for training the network, e.g. reconstruction loss for autoencoders
    neural_network : torch.nn.Module | tuple
        the input neural network.
        Can also be a tuple consisting of the neural network class (torch.nn.Module) and the initialization parameters (dict)
    neural_network_weights : str
        Path to a file containing the state_dict of the neural_network.
    embedding_size : int
        size of the embedding within the neural network
    clustering_loss_weight : float
        weight of the clustering loss
    ssl_loss_weight : float
        weight of the self-supervised learning (ssl) loss
    custom_dataloaders : tuple
        tuple consisting of a trainloader (random order) at the first and a test loader (non-random order) at the second position.
        Can also be a tuple of strings, where the first entry is the path to a saved trainloader and the second entry the path to a saved testloader.
        In this case the dataloaders will be loaded by torch.load(PATH).
        If None, the default dataloaders will be used
    augmentation_invariance : bool
        If True, augmented samples provided in custom_dataloaders[0] will be used to learn cluster assignments that are invariant to the augmentation transformations
    device : torch.device
        The device on which to perform the computations
    random_state : np.random.RandomState
        use a fixed random state to get a repeatable solution

    Returns
    -------
    tuple : (BinaryClusterTree, np.ndarray, torch.nn.Module)
        The cluster tree as identified by DeepECT,
        The labels as identified by DeepECT,
        The final neural network
    """
    # Get initial setting (device, dataloaders, pretrained AE and initial clustering result).
    # The initial clustering is a KMeans with two clusters whose centers initialize the first two leaf nodes.
    initialization = get_default_deep_clustering_initialization(
        X, 2, batch_size, pretrain_optimizer_params, pretrain_epochs, optimizer_class, ssl_loss_fn,
        neural_network, embedding_size, custom_dataloaders, KMeans, {"n_init": 20}, device,
        random_state, neural_network_weights=neural_network_weights)
    device, trainloader, testloader, _, neural_network, _, _, _, init_leafnode_centers, _ = initialization
    cluster_tree = BinaryClusterTree(_DeepECT_ClusterTreeNode)
    # Setup DeepECT Module
    deepect_module = _DeepECT_Module(cluster_tree, max_n_leaf_nodes, grow_interval, pruning_threshold,
                                     augmentation_invariance).to(device)
    # Use DeepECT optimizer parameters (usually learning rate is reduced by a magnitude of 10)
    optimizer = optimizer_class(list(neural_network.parameters()), **clustering_optimizer_params)
    # Create the first two leaf nodes (labels 0 and 1); their centers are added to the optimizer
    left_node, right_node = cluster_tree.split_cluster(0, 1)
    left_node.set_center_weight_and_torch_labels(init_leafnode_centers[0], 1, optimizer, device)
    right_node.set_center_weight_and_torch_labels(init_leafnode_centers[1], 1, optimizer, device)
    # Has to be called only once as the parent is the same for the left and right node
    left_node.update_parents_torch_labels(device)
    # DeepECT Training loop
    deepect_module.fit(neural_network, trainloader, testloader, clustering_epochs, device, optimizer, ssl_loss_fn,
                       clustering_loss_weight, ssl_loss_weight, random_state)
    # Get labels
    labels = predict_batchwise(testloader, neural_network, deepect_module)
    return cluster_tree, labels, neural_network
class DeepECT(_AbstractDeepClusteringAlgo):
    """
    The Deep Embedded Cluster Tree (DeepECT) algorithm.
    First, a neural network will be trained (will be skipped if input neural network is given).
    Afterward, a cluster tree will be grown and the network will be optimized using the DeepECT loss function.

    Parameters
    ----------
    max_n_leaf_nodes : int
        Maximum number of leaf nodes in the cluster tree (default: 20)
    batch_size : int
        Size of the data batches (default: 256)
    pretrain_optimizer_params : dict
        parameters of the optimizer for the pretraining of the neural network, includes the learning rate (default: {"lr": 1e-3})
    clustering_optimizer_params : dict
        parameters of the optimizer for the actual clustering procedure, includes the learning rate (default: {"lr": 1e-4})
    pretrain_epochs : int
        number of epochs for the pretraining of the neural network (default: 50)
    clustering_epochs : int
        Number of epochs for the actual clustering procedure (default: 200)
    grow_interval : int
        Number of epochs after which the tree is grown (default: 2)
    pruning_threshold : float
        The threshold for pruning the tree (default: 0.1)
    optimizer_class : torch.optim.Optimizer
        The optimizer class (default: torch.optim.Adam)
    ssl_loss_fn : torch.nn.modules.loss._Loss
        self-supervised learning (ssl) loss function for training the network, e.g. reconstruction loss for autoencoders (default: torch.nn.MSELoss())
    neural_network : torch.nn.Module | tuple
        the input neural network. If None, a new FeedforwardAutoencoder will be created.
        Can also be a tuple consisting of the neural network class (torch.nn.Module) and the initialization parameters (dict) (default: None)
    neural_network_weights : str
        Path to a file containing the state_dict of the neural_network (default: None)
    embedding_size : int
        Size of the embedding within the neural network (default: 10)
    clustering_loss_weight : float
        weight of the clustering loss (default: 1.0)
    ssl_loss_weight : float
        weight of the self-supervised learning (ssl) loss (default: 1.0)
    custom_dataloaders : tuple
        tuple consisting of a trainloader (random order) at the first and a test loader (non-random order) at the second position.
        Can also be a tuple of strings, where the first entry is the path to a saved trainloader and the second entry the path to a saved testloader.
        In this case the dataloaders will be loaded by torch.load(PATH).
        If None, the default dataloaders will be used (default: None)
    augmentation_invariance : bool
        If True, augmented samples provided in custom_dataloaders[0] will be used to learn cluster assignments that are invariant to the augmentation transformations (default: False)
    device : torch.device
        The device on which to perform the computations.
        If device is None then it will be automatically chosen: if a gpu is available the gpu with the highest amount of free memory will be chosen (default: None)
    random_state : np.random.RandomState
        Use a fixed random state to get a repeatable solution. Can also be of type int (default: None)

    Attributes
    ----------
    labels_ : np.ndarray
        The final labels (labels of the leaf nodes of the cluster tree that the samples are assigned to)
    tree_ : BinaryClusterTree
        The cluster tree after training
    neural_network : torch.nn.Module
        The final neural network
    """

    def __init__(self, max_n_leaf_nodes: int = 20, batch_size: int = 256, pretrain_optimizer_params: dict = None,
                 clustering_optimizer_params: dict = None, pretrain_epochs: int = 50, clustering_epochs: int = 200,
                 grow_interval: int = 2, pruning_threshold: float = 0.1,
                 optimizer_class: torch.optim.Optimizer = torch.optim.Adam,
                 ssl_loss_fn: torch.nn.modules.loss._Loss = torch.nn.MSELoss(),
                 neural_network: torch.nn.Module | tuple = None, neural_network_weights: str = None,
                 embedding_size: int = 10, clustering_loss_weight: float = 1., ssl_loss_weight: float = 1.,
                 custom_dataloaders: tuple = None, augmentation_invariance: bool = False,
                 device: torch.device = None, random_state: np.random.RandomState | int = None):
        super().__init__(batch_size, neural_network, neural_network_weights, embedding_size, device, random_state)
        self.max_n_leaf_nodes = max_n_leaf_nodes
        # Fresh dictionaries are created here so that no mutable default argument is shared between instances
        self.pretrain_optimizer_params = {
            "lr": 1e-3} if pretrain_optimizer_params is None else pretrain_optimizer_params
        self.clustering_optimizer_params = {
            "lr": 1e-4} if clustering_optimizer_params is None else clustering_optimizer_params
        self.pretrain_epochs = pretrain_epochs
        self.clustering_epochs = clustering_epochs
        self.grow_interval = grow_interval
        self.pruning_threshold = pruning_threshold
        self.optimizer_class = optimizer_class
        self.ssl_loss_fn = ssl_loss_fn
        self.clustering_loss_weight = clustering_loss_weight
        self.ssl_loss_weight = ssl_loss_weight
        self.custom_dataloaders = custom_dataloaders
        self.augmentation_invariance = augmentation_invariance

    def fit(self, X: np.ndarray, y: np.ndarray = None) -> "DeepECT":
        """
        Initiate the actual clustering process on the input data set.
        The resulting cluster labels will be stored in the labels_ attribute.

        Parameters
        ----------
        X : np.ndarray
            the given data set
        y : np.ndarray
            the labels (can be ignored)

        Returns
        -------
        self : DeepECT
            This instance of the DeepECT algorithm
        """
        super().fit(X, y)
        tree, labels, neural_network = _deep_ect(X, self.max_n_leaf_nodes, self.batch_size,
                                                 self.pretrain_optimizer_params, self.clustering_optimizer_params,
                                                 self.pretrain_epochs, self.clustering_epochs, self.grow_interval,
                                                 self.pruning_threshold, self.optimizer_class, self.ssl_loss_fn,
                                                 self.neural_network, self.neural_network_weights,
                                                 self.embedding_size, self.clustering_loss_weight,
                                                 self.ssl_loss_weight, self.custom_dataloaders,
                                                 self.augmentation_invariance, self.device, self.random_state)
        self.tree_ = tree
        self.labels_ = labels
        self.neural_network = neural_network
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predicts the labels of the input data.
        Each sample is embedded and receives the label of the leaf node with the closest center.

        Parameters
        ----------
        X : np.ndarray
            input data

        Returns
        -------
        predicted_labels : np.ndarray
            The predicted labels
        """
        X_embed = self.transform(X)
        leaf_nodes, _ = self.tree_.get_leaf_and_split_nodes()
        leaf_centers = np.array([leaf.center.data.detach().cpu().numpy() for leaf in leaf_nodes])
        # The first label of a leaf node is used as its cluster label
        leaf_labels = np.array([leaf.labels[0] for leaf in leaf_nodes])
        cluster_center_assignments = embedded_kmeans_prediction(X_embed, leaf_centers)
        predicted_labels = leaf_labels[cluster_center_assignments]
        return predicted_labels

    def flat_clustering(self, n_leaf_nodes_to_keep: int) -> np.ndarray:
        """
        Transform the predicted labels into a flat clustering result by only keeping n_leaf_nodes_to_keep leaf nodes in the tree.
        Returns labels as if the clustering procedure would have stopped at the specified number of nodes.
        Note that each leaf node corresponds to a cluster.

        Parameters
        ----------
        n_leaf_nodes_to_keep : int
            The number of leaf nodes to keep in the cluster tree

        Returns
        -------
        labels_pruned : np.ndarray
            The new cluster labels
        """
        # getattr is used so that the explanatory message is also shown if labels_ has never been set
        assert getattr(self, "labels_", None) is not None, \
            "The DeepECT algorithm has not run yet. Use the fit() function first."
        # Pruning is performed on a copy so that the fitted tree stays untouched
        tree_copy = copy.deepcopy(self.tree_)
        labels_pruned = tree_copy.prune_to_n_leaf_nodes(n_leaf_nodes_to_keep, self.labels_)
        return labels_pruned