"""
@authors:
Lukas Miklautz,
Dominik Mautz,
Collin Leiber
"""
from clustpy.deep._utils import encode_batchwise, squared_euclidean_distance, predict_batchwise, \
embedded_kmeans_prediction
from clustpy.deep._train_utils import get_default_deep_clustering_initialization
from clustpy.deep._abstract_deep_clustering_algo import _AbstractDeepClusteringAlgo
import torch
import numpy as np
from sklearn.cluster import KMeans
from sklearn.base import ClusterMixin
import tqdm
def _dec(X: np.ndarray, n_clusters: int, alpha: float, batch_size: int, pretrain_optimizer_params: dict,
         clustering_optimizer_params: dict, pretrain_epochs: int, clustering_epochs: int,
         optimizer_class: torch.optim.Optimizer, ssl_loss_fn: torch.nn.modules.loss._Loss,
         neural_network: torch.nn.Module | tuple, neural_network_weights: str, embedding_size: int,
         ssl_loss_weight: float, clustering_loss_weight: float, custom_dataloaders: tuple,
         augmentation_invariance: bool, initial_clustering_class: ClusterMixin, initial_clustering_params: dict,
         device: torch.device, random_state: np.random.RandomState) -> (
        np.ndarray, np.ndarray, np.ndarray, np.ndarray, torch.nn.Module):
    """
    Run the complete DEC procedure on a data set.

    The shared deep clustering initialization is executed first, then the network and the cluster centers
    are optimized jointly by a _DEC_Module, and finally KMeans is executed once more on the resulting embedding.

    Parameters
    ----------
    X : np.ndarray / torch.Tensor
        the given data set. Can be a np.ndarray or a torch.Tensor
    n_clusters : int
        number of clusters. Can be None if a corresponding initial_clustering_class is given, that can determine the number of clusters, e.g. DBSCAN
    alpha : float
        alpha value for the prediction
    batch_size : int
        size of the data batches
    pretrain_optimizer_params : dict
        parameters of the optimizer for the pretraining of the neural network, includes the learning rate
    clustering_optimizer_params : dict
        parameters of the optimizer for the actual clustering procedure, includes the learning rate
    pretrain_epochs : int
        number of epochs for the pretraining of the neural network
    clustering_epochs : int
        number of epochs for the actual clustering procedure
    optimizer_class : torch.optim.Optimizer
        the optimizer class; it is instantiated here with the clustering_optimizer_params
    ssl_loss_fn : torch.nn.modules.loss._Loss
        self-supervised learning (ssl) loss function for training the network, e.g. reconstruction loss for autoencoders
    neural_network : torch.nn.Module | tuple
        the input neural network.
        Can also be a tuple consisting of the neural network class (torch.nn.Module) and the initialization parameters (dict)
    neural_network_weights : str
        Path to a file containing the state_dict of the neural_network.
    embedding_size : int
        size of the embedding within the neural network
    ssl_loss_weight : float
        weight of the self-supervised learning (ssl) loss
    clustering_loss_weight : float
        weight of the clustering loss
    custom_dataloaders : tuple
        tuple consisting of a trainloader (random order) at the first and a test loader (non-random order) at the second position.
        Can also be a tuple of strings, where the first entry is the path to a saved trainloader and the second entry the path to a saved testloader.
        In this case the dataloaders will be loaded by torch.load(PATH).
        If None, the default dataloaders will be used
    augmentation_invariance : bool
        If True, augmented samples provided in custom_dataloaders[0] will be used to learn
        cluster assignments that are invariant to the augmentation transformations
    initial_clustering_class : ClusterMixin
        clustering class to obtain the initial cluster labels after the pretraining
    initial_clustering_params : dict
        parameters for the initial clustering class
    device : torch.device
        The device on which to perform the computations
    random_state : np.random.RandomState
        use a fixed random state to get a repeatable solution

    Returns
    -------
    tuple : (np.ndarray, np.ndarray, np.ndarray, np.ndarray, torch.nn.Module)
        The labels as identified by a final KMeans execution,
        The cluster centers as identified by a final KMeans execution,
        The labels as identified by DEC after the training terminated,
        The cluster centers as identified by DEC after the training terminated,
        The final neural network
    """
    # Common setup; only the entries that are needed below are kept, the rest is discarded
    setup = get_default_deep_clustering_initialization(
        X, n_clusters, batch_size, pretrain_optimizer_params, pretrain_epochs, optimizer_class, ssl_loss_fn,
        neural_network, embedding_size, custom_dataloaders, initial_clustering_class, initial_clustering_params,
        device, random_state, neural_network_weights=neural_network_weights)
    device, trainloader, testloader, _, neural_network, _, n_clusters, _, init_centers, _ = setup
    # The DEC module holds the (learnable) cluster centers
    dec_module = _DEC_Module(init_centers, alpha, augmentation_invariance)
    dec_module = dec_module.to(device)
    # One optimizer updates network weights and centers together
    # (the clustering learning rate is usually a magnitude of 10 below the pretraining one)
    trainable_params = [*neural_network.parameters(), *dec_module.parameters()]
    optimizer = optimizer_class(trainable_params, **clustering_optimizer_params)
    # Joint optimization
    dec_module.fit(neural_network=neural_network, trainloader=trainloader, n_epochs=clustering_epochs,
                   device=device, optimizer=optimizer, ssl_loss_fn=ssl_loss_fn,
                   ssl_loss_weight=ssl_loss_weight, clustering_loss_weight=clustering_loss_weight)
    # Result as seen by DEC itself
    dec_labels = predict_batchwise(testloader, neural_network, dec_module)
    dec_centers = dec_module.centers.detach().cpu().numpy()
    # Result of a fresh KMeans run within the learned embedding
    embedded_data = encode_batchwise(testloader, neural_network)
    final_kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(embedded_data)
    return final_kmeans.labels_, final_kmeans.cluster_centers_, dec_labels, dec_centers, neural_network
def _dec_predict(centers: torch.Tensor, embedded: torch.Tensor, alpha: float, weights: torch.Tensor) -> torch.Tensor:
    """
    Compute the soft cluster assignments of embedded samples.

    Each (weighted) squared distance d between a sample and a center is mapped to
    (1 + d / alpha) ^ (-(alpha + 1) / 2) and the values are normalized per sample so that they sum to one.

    Parameters
    ----------
    centers : torch.Tensor
        the cluster centers
    embedded : torch.Tensor
        the embedded samples
    alpha : float
        the alpha value
    weights : torch.Tensor
        feature weights for the squared Euclidean distance

    Returns
    -------
    prob : torch.Tensor
        The predicted soft labels
    """
    dist_sq = squared_euclidean_distance(embedded, centers, weights)
    exponent = -1.0 * (alpha + 1.0) / 2.0
    kernel = torch.pow(1.0 + dist_sq / alpha, exponent)
    # Normalize across the cluster dimension
    return kernel / kernel.sum(dim=1, keepdim=True)
def _dec_compression_value(pred_labels: torch.Tensor) -> torch.Tensor:
    """
    Compute the DEC compression values (the auxiliary target distribution).

    The soft labels are squared, divided by the soft frequency of their cluster (column sum)
    and afterwards normalized so that each row sums to one.

    Parameters
    ----------
    pred_labels : torch.Tensor
        the predictions of the embedded samples.

    Returns
    -------
    p : torch.Tensor
        The compression values
    """
    # Soft frequency of every cluster, kept as a row vector for broadcasting
    cluster_mass = torch.sum(pred_labels, dim=0, keepdim=True)
    weighted = pred_labels ** 2 / cluster_mass
    # Row-wise normalization
    return weighted / weighted.sum(dim=1, keepdim=True)
def _dec_compression_loss_fn(pred_labels: torch.Tensor, target_p: torch.Tensor = None) -> torch.Tensor:
    """
    Compute the DEC clustering loss from soft labels and (pseudo) targets.

    The loss is the negative mean over all samples of sum_k target_p[k] * log(pred_labels[k] + 1e-8).

    Parameters
    ----------
    pred_labels : torch.Tensor
        the predictions of the embedded samples.
    target_p : torch.Tensor
        dec_compression_value used as pseudo target labels.
        If None, the targets are derived from pred_labels and detached from the computation graph (default: None)

    Returns
    -------
    loss : torch.Tensor
        The final loss
    """
    # Targets are constants with respect to the optimization, hence the detach
    targets = _dec_compression_value(pred_labels).detach().data if target_p is None else target_p
    log_pred = torch.log(pred_labels + 1e-8)
    per_sample = (targets * log_pred).sum(dim=1)
    return -1.0 * per_sample.mean()
class _DEC_Module(torch.nn.Module):
    """
    Torch module bundling the DEC specific parts: learnable centers, predictions, losses and the training loop.

    Parameters
    ----------
    init_centers : np.ndarray
        The initial cluster centers
    alpha : double
        alpha value for the prediction method
    augmentation_invariance : bool
        If True, augmented samples provided in custom_dataloaders[0] will be used to learn
        cluster assignments that are invariant to the augmentation transformations (default: False)

    Attributes
    ----------
    alpha : float
        the alpha value
    centers : torch.Tensor:
        the cluster centers
    augmentation_invariance : bool
        Is augmentation invariance used
    """

    def __init__(self, init_centers: np.ndarray, alpha: float, augmentation_invariance: bool = False):
        super().__init__()
        self.alpha = alpha
        self.augmentation_invariance = augmentation_invariance
        # Registering the centers as a Parameter lets the optimizer update them
        center_tensor = torch.tensor(init_centers)
        self.centers = torch.nn.Parameter(center_tensor, requires_grad=True)

    def predict(self, embedded: torch.Tensor, weights: torch.Tensor = None) -> torch.Tensor:
        """
        Soft prediction for embedded samples using the current centers.

        Parameters
        ----------
        embedded : torch.Tensor
            the embedded samples
        weights : torch.Tensor
            feature weights for the squared Euclidean distance within the dec_predict method (default: None)

        Returns
        -------
        pred : torch.Tensor
            The predicted soft labels
        """
        return _dec_predict(self.centers, embedded, self.alpha, weights=weights)

    def predict_hard(self, embedded: torch.Tensor, weights: torch.Tensor = None) -> torch.Tensor:
        """
        Hard prediction for embedded samples: the argmax over the soft prediction.

        Parameters
        ----------
        embedded : torch.Tensor
            the embedded samples
        weights : torch.Tensor
            feature weights for the squared Euclidean distance within the dec_predict method (default: None)

        Returns
        -------
        pred_hard : torch.Tensor
            The predicted hard labels
        """
        soft_labels = self.predict(embedded, weights=weights)
        return torch.argmax(soft_labels, dim=1)

    def dec_loss(self, embedded: torch.Tensor, weights: torch.Tensor = None) -> torch.Tensor:
        """
        DEC loss of embedded samples; the pseudo targets are derived from the prediction itself.

        Parameters
        ----------
        embedded : torch.Tensor
            the embedded samples
        weights : torch.Tensor
            feature weights for the squared Euclidean distance within the dec_predict method (default: None)

        Returns
        -------
        loss : torch.Tensor
            the final DEC loss
        """
        soft_labels = _dec_predict(self.centers, embedded, self.alpha, weights=weights)
        return _dec_compression_loss_fn(soft_labels)

    def dec_augmentation_invariance_loss(self, embedded: torch.Tensor, embedded_aug: torch.Tensor,
                                         weights: torch.Tensor = None) -> torch.Tensor:
        """
        DEC loss with augmentation invariance.

        The pseudo targets are computed from the clean samples only and are used for the clean as well as
        for the augmented predictions. The returned loss is the average of both losses.

        Parameters
        ----------
        embedded : torch.Tensor
            the embedded samples
        embedded_aug : torch.Tensor
            the embedded augmented samples
        weights : torch.Tensor
            feature weights for the squared Euclidean distance within the dec_predict method (default: None)

        Returns
        -------
        loss : torch.Tensor
            the final DEC loss
        """
        clean_pred = _dec_predict(self.centers, embedded, self.alpha, weights=weights)
        aug_pred = _dec_predict(self.centers, embedded_aug, self.alpha, weights=weights)
        # Targets come from the clean samples and are treated as constants
        pseudo_targets = _dec_compression_value(clean_pred).detach().data
        clean_loss = _dec_compression_loss_fn(clean_pred, pseudo_targets)
        # Sharing the clean targets pushes the augmented assignment towards the clean one
        aug_loss = _dec_compression_loss_fn(aug_pred, pseudo_targets)
        return (clean_loss + aug_loss) / 2

    def _loss(self, batch: list, neural_network: torch.nn.Module, clustering_loss_weight: float,
              ssl_loss_weight: float, ssl_loss_fn: torch.nn.modules.loss._Loss, device: torch.device) -> torch.Tensor:
        """
        Combined loss of a minibatch: weighted clustering loss plus the optional weighted ssl loss.

        Parameters
        ----------
        batch : list
            the minibatch
        neural_network : torch.nn.Module
            the neural network
        clustering_loss_weight : float
            weight of the clustering loss
        ssl_loss_weight : float
            weight of the self-supervised learning (ssl) loss. If 0, the ssl loss is not evaluated at all
        ssl_loss_fn : torch.nn.modules.loss._Loss
            self-supervised learning (ssl) loss function, e.g. reconstruction loss for autoencoders
        device : torch.device
            device to be trained on

        Returns
        -------
        loss : torch.Tensor
            the final DEC loss
        """
        use_augmentation = self.augmentation_invariance
        embedded_aug = None
        total = torch.tensor(0., device=device)
        if ssl_loss_weight == 0:
            # Plain DEC: no ssl loss, the samples only have to be encoded
            if use_augmentation:
                # batch[1] is encoded as the augmented data, batch[2] as the original data
                embedded_aug = neural_network.encode(batch[1].to(device))
                embedded = neural_network.encode(batch[2].to(device))
            else:
                embedded = neural_network.encode(batch[1].to(device))
        else:
            # The loss functions of the network deliver the embeddings as well
            if use_augmentation:
                ssl_loss, embedded, _, embedded_aug, _ = neural_network.loss_augmentation(batch, ssl_loss_fn,
                                                                                          device)
            else:
                ssl_loss, embedded, _ = neural_network.loss(batch, ssl_loss_fn, device)
            total += ssl_loss_weight * ssl_loss
        # Cluster loss
        if use_augmentation:
            cluster_loss = self.dec_augmentation_invariance_loss(embedded, embedded_aug)
        else:
            cluster_loss = self.dec_loss(embedded)
        total += cluster_loss * clustering_loss_weight
        return total

    def fit(self, neural_network: torch.nn.Module, trainloader: torch.utils.data.DataLoader, n_epochs: int,
            device: torch.device, optimizer: torch.optim.Optimizer, ssl_loss_fn: torch.nn.modules.loss._Loss,
            ssl_loss_weight: float, clustering_loss_weight: float) -> '_DEC_Module':
        """
        Train the _DEC_Module in place.

        One optimizer step is performed per minibatch. The loss summed over an epoch is shown in the progress bar.

        Parameters
        ----------
        neural_network : torch.nn.Module
            the neural network
        trainloader : torch.utils.data.DataLoader
            dataloader to be used for training
        n_epochs : int
            number of epochs for the clustering procedure
        device : torch.device
            device to be trained on
        optimizer : torch.optim.Optimizer
            the optimizer for training
        ssl_loss_fn : torch.nn.modules.loss._Loss
            self-supervised learning (ssl) loss function for training the network, e.g. reconstruction loss for autoencoders
        ssl_loss_weight : float
            weight of the self-supervised learning (ssl) loss
        clustering_loss_weight : float
            weight of the clustering loss

        Returns
        -------
        self : _DEC_Module
            this instance of the _DEC_Module
        """
        progress = tqdm.trange(n_epochs, desc="DEC training")
        for _epoch in progress:
            epoch_loss = 0
            for batch in trainloader:
                batch_loss = self._loss(batch=batch, neural_network=neural_network,
                                        clustering_loss_weight=clustering_loss_weight,
                                        ssl_loss_weight=ssl_loss_weight, ssl_loss_fn=ssl_loss_fn, device=device)
                epoch_loss += batch_loss.item()
                # Gradient step
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            progress.set_postfix({"Loss": epoch_loss})
        return self
class DEC(_AbstractDeepClusteringAlgo):
    """
    The Deep Embedded Clustering (DEC) algorithm.
    First, a neural_network will be trained (will be skipped if input neural network is given).
    Afterward, KMeans identifies the initial clusters.
    Last, the network will be optimized using the DEC loss function.

    Parameters
    ----------
    n_clusters : int
        number of clusters. Can be None if a corresponding initial_clustering_class is given, that can determine the number of clusters, e.g. DBSCAN
    alpha : float
        alpha value for the prediction (default: 1.0)
    batch_size : int
        size of the data batches (default: 256)
    pretrain_optimizer_params : dict
        parameters of the optimizer for the pretraining of the neural network, includes the learning rate (default: {"lr": 1e-3})
    clustering_optimizer_params : dict
        parameters of the optimizer for the actual clustering procedure, includes the learning rate (default: {"lr": 1e-4})
    pretrain_epochs : int
        number of epochs for the pretraining of the neural network (default: 100)
    clustering_epochs : int
        number of epochs for the actual clustering procedure (default: 150)
    optimizer_class : torch.optim.Optimizer
        the optimizer class (default: torch.optim.Adam)
    ssl_loss_fn : torch.nn.modules.loss._Loss
        self-supervised learning (ssl) loss function for training the network, e.g. reconstruction loss for autoencoders (default: torch.nn.MSELoss())
    neural_network : torch.nn.Module | tuple
        the input neural network. If None, a new FeedforwardAutoencoder will be created.
        Can also be a tuple consisting of the neural network class (torch.nn.Module) and the initialization parameters (dict) (default: None)
    neural_network_weights : str
        Path to a file containing the state_dict of the neural_network (default: None)
    embedding_size : int
        size of the embedding within the neural network (default: 10)
    clustering_loss_weight : float
        weight of the clustering loss compared to the reconstruction loss (default: 1.0)
    custom_dataloaders : tuple
        tuple consisting of a trainloader (random order) at the first and a test loader (non-random order) at the second position.
        Can also be a tuple of strings, where the first entry is the path to a saved trainloader and the second entry the path to a saved testloader.
        In this case the dataloaders will be loaded by torch.load(PATH).
        If None, the default dataloaders will be used (default: None)
    augmentation_invariance : bool
        If True, augmented samples provided in custom_dataloaders[0] will be used to learn
        cluster assignments that are invariant to the augmentation transformations (default: False)
    initial_clustering_class : ClusterMixin
        clustering class to obtain the initial cluster labels after the pretraining (default: KMeans)
    initial_clustering_params : dict
        parameters for the initial clustering class (default: {})
    device : torch.device
        The device on which to perform the computations.
        If device is None then it will be automatically chosen: if a gpu is available the gpu with the highest amount of free memory will be chosen (default: None)
    random_state : np.random.RandomState | int
        use a fixed random state to get a repeatable solution. Can also be of type int (default: None)

    Attributes
    ----------
    labels_ : np.ndarray
        The final labels (obtained by a final KMeans execution)
    cluster_centers_ : np.ndarray
        The final cluster centers (obtained by a final KMeans execution)
    dec_labels_ : np.ndarray
        The final DEC labels
    dec_cluster_centers_ : np.ndarray
        The final DEC cluster centers
    neural_network : torch.nn.Module
        The final neural network

    Examples
    ----------
    >>> from clustpy.data import create_subspace_data
    >>> from clustpy.deep import DEC
    >>> data, labels = create_subspace_data(1500, subspace_features=(3, 50), random_state=1)
    >>> dec = DEC(n_clusters=3, pretrain_epochs=3, clustering_epochs=3)
    >>> dec.fit(data)

    References
    ----------
    Xie, Junyuan, Ross Girshick, and Ali Farhadi. "Unsupervised deep embedding for clustering analysis."
    International conference on machine learning. 2016.
    """

    def __init__(self, n_clusters: int, alpha: float = 1.0, batch_size: int = 256,
                 pretrain_optimizer_params: dict = None, clustering_optimizer_params: dict = None,
                 pretrain_epochs: int = 100, clustering_epochs: int = 150,
                 optimizer_class: torch.optim.Optimizer = torch.optim.Adam,
                 ssl_loss_fn: torch.nn.modules.loss._Loss = torch.nn.MSELoss(),
                 neural_network: torch.nn.Module | tuple = None, neural_network_weights: str = None,
                 embedding_size: int = 10, clustering_loss_weight: float = 1., custom_dataloaders: tuple = None,
                 augmentation_invariance: bool = False, initial_clustering_class: ClusterMixin = KMeans,
                 initial_clustering_params: dict = None, device: torch.device = None,
                 random_state: np.random.RandomState | int = None):
        super().__init__(batch_size, neural_network, neural_network_weights, embedding_size, device, random_state)
        self.n_clusters = n_clusters
        self.alpha = alpha
        # None stands for the documented default dictionaries; a fresh dict is created per instance
        self.pretrain_optimizer_params = {
            "lr": 1e-3} if pretrain_optimizer_params is None else pretrain_optimizer_params
        self.clustering_optimizer_params = {
            "lr": 1e-4} if clustering_optimizer_params is None else clustering_optimizer_params
        self.pretrain_epochs = pretrain_epochs
        self.clustering_epochs = clustering_epochs
        self.optimizer_class = optimizer_class
        self.ssl_loss_fn = ssl_loss_fn
        self.clustering_loss_weight = clustering_loss_weight
        self.custom_dataloaders = custom_dataloaders
        self.augmentation_invariance = augmentation_invariance
        self.initial_clustering_class = initial_clustering_class
        self.initial_clustering_params = {} if initial_clustering_params is None else initial_clustering_params
        self.ssl_loss_weight = 0  # DEC does not use ssl loss when clustering

    def fit(self, X: np.ndarray, y: np.ndarray = None) -> 'DEC':
        """
        Initiate the actual clustering process on the input data set.
        The resulting cluster labels will be stored in the labels_ attribute.

        Parameters
        ----------
        X : np.ndarray
            the given data set
        y : np.ndarray
            the labels (can be ignored)

        Returns
        -------
        self : DEC
            this instance of the DEC algorithm
        """
        super().fit(X, y)
        kmeans_labels, kmeans_centers, dec_labels, dec_centers, neural_network = _dec(
            X, self.n_clusters, self.alpha, self.batch_size, self.pretrain_optimizer_params,
            self.clustering_optimizer_params, self.pretrain_epochs, self.clustering_epochs,
            self.optimizer_class, self.ssl_loss_fn, self.neural_network, self.neural_network_weights,
            self.embedding_size, self.ssl_loss_weight, self.clustering_loss_weight, self.custom_dataloaders,
            self.augmentation_invariance, self.initial_clustering_class, self.initial_clustering_params,
            self.device, self.random_state)
        # labels_ / cluster_centers_ hold the result of the final KMeans execution,
        # the result of DEC itself is stored separately
        self.labels_ = kmeans_labels
        self.cluster_centers_ = kmeans_centers
        self.dec_labels_ = dec_labels
        self.dec_cluster_centers_ = dec_centers
        self.neural_network = neural_network
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predicts the labels of the input data.
        The data is embedded by the transform method and labeled by using cluster_centers_
        (the centers of the final KMeans execution).

        Parameters
        ----------
        X : np.ndarray
            input data

        Returns
        -------
        predicted_labels : np.ndarray
            The predicted labels
        """
        X_embed = self.transform(X)
        predicted_labels = embedded_kmeans_prediction(X_embed, self.cluster_centers_)
        return predicted_labels
class IDEC(DEC):
    """
    The Improved Deep Embedded Clustering (IDEC) algorithm.
    Is equal to the DEC algorithm but uses the self-supervised learning loss also during the clustering optimization.
    Further, clustering_loss_weight is set to 0.1 instead of 1 when using the default settings.

    Parameters
    ----------
    n_clusters : int
        number of clusters. Can be None if a corresponding initial_clustering_class is given, that can determine the number of clusters, e.g. DBSCAN
    alpha : float
        alpha value for the prediction (default: 1.0)
    batch_size : int
        size of the data batches (default: 256)
    pretrain_optimizer_params : dict
        parameters of the optimizer for the pretraining of the neural network, includes the learning rate (default: {"lr": 1e-3})
    clustering_optimizer_params : dict
        parameters of the optimizer for the actual clustering procedure, includes the learning rate (default: {"lr": 1e-4})
    pretrain_epochs : int
        number of epochs for the pretraining of the neural network (default: 100)
    clustering_epochs : int
        number of epochs for the actual clustering procedure (default: 150)
    optimizer_class : torch.optim.Optimizer
        the optimizer class (default: torch.optim.Adam)
    ssl_loss_fn : torch.nn.modules.loss._Loss
        self-supervised learning (ssl) loss function for training the network, e.g. reconstruction loss for autoencoders (default: torch.nn.MSELoss())
    neural_network : torch.nn.Module | tuple
        the input neural network. If None, a new FeedforwardAutoencoder will be created.
        Can also be a tuple consisting of the neural network class (torch.nn.Module) and the initialization parameters (dict) (default: None)
    neural_network_weights : str
        Path to a file containing the state_dict of the neural_network (default: None)
    embedding_size : int
        size of the embedding within the neural network (default: 10)
    clustering_loss_weight : float
        weight of the clustering loss compared to the reconstruction loss (default: 0.1)
    ssl_loss_weight : float
        weight of the self-supervised learning (ssl) loss (default: 1.0)
    custom_dataloaders : tuple
        tuple consisting of a trainloader (random order) at the first and a test loader (non-random order) at the second position.
        Can also be a tuple of strings, where the first entry is the path to a saved trainloader and the second entry the path to a saved testloader.
        In this case the dataloaders will be loaded by torch.load(PATH).
        If None, the default dataloaders will be used (default: None)
    augmentation_invariance : bool
        If True, augmented samples provided in custom_dataloaders[0] will be used to learn
        cluster assignments that are invariant to the augmentation transformations (default: False)
    initial_clustering_class : ClusterMixin
        clustering class to obtain the initial cluster labels after the pretraining (default: KMeans)
    initial_clustering_params : dict
        parameters for the initial clustering class (default: {})
    device : torch.device
        The device on which to perform the computations.
        If device is None then it will be automatically chosen: if a gpu is available the gpu with the highest amount of free memory will be chosen (default: None)
    random_state : np.random.RandomState | int
        use a fixed random state to get a repeatable solution. Can also be of type int (default: None)

    Attributes
    ----------
    labels_ : np.ndarray
        The final labels (obtained by a final KMeans execution)
    cluster_centers_ : np.ndarray
        The final cluster centers (obtained by a final KMeans execution)
    dec_labels_ : np.ndarray
        The final DEC labels
    dec_cluster_centers_ : np.ndarray
        The final DEC cluster centers
    neural_network : torch.nn.Module
        The final neural network

    Examples
    ----------
    >>> from clustpy.data import create_subspace_data
    >>> from clustpy.deep import IDEC
    >>> data, labels = create_subspace_data(1500, subspace_features=(3, 50), random_state=1)
    >>> idec = IDEC(n_clusters=3, pretrain_epochs=3, clustering_epochs=3)
    >>> idec.fit(data)

    References
    ----------
    Guo, Xifeng, et al. "Improved deep embedded clustering with local structure preservation." IJCAI. 2017.
    """

    def __init__(self, n_clusters: int, alpha: float = 1.0, batch_size: int = 256,
                 pretrain_optimizer_params: dict = None,
                 clustering_optimizer_params: dict = None, pretrain_epochs: int = 100,
                 clustering_epochs: int = 150, optimizer_class: torch.optim.Optimizer = torch.optim.Adam,
                 ssl_loss_fn: torch.nn.modules.loss._Loss = torch.nn.MSELoss(),
                 neural_network: torch.nn.Module | tuple = None, neural_network_weights: str = None,
                 embedding_size: int = 10, clustering_loss_weight: float = 0.1, ssl_loss_weight: float = 1.0,
                 custom_dataloaders: tuple = None, augmentation_invariance: bool = False,
                 initial_clustering_class: ClusterMixin = KMeans, initial_clustering_params: dict = None,
                 device: torch.device = None, random_state: np.random.RandomState | int = None):
        super().__init__(n_clusters, alpha, batch_size, pretrain_optimizer_params, clustering_optimizer_params,
                         pretrain_epochs, clustering_epochs, optimizer_class, ssl_loss_fn, neural_network,
                         neural_network_weights, embedding_size, clustering_loss_weight, custom_dataloaders,
                         augmentation_invariance, initial_clustering_class,
                         initial_clustering_params, device, random_state)
        # DEC.__init__ sets ssl_loss_weight to 0; IDEC overrides it afterwards with the given value
        self.ssl_loss_weight = ssl_loss_weight