package de.bioforscher.singa.mathematics.algorithms.clustering;

import de.bioforscher.singa.mathematics.matrices.LabeledMatrix;
import de.bioforscher.singa.mathematics.matrices.LabeledRegularMatrix;
import de.bioforscher.singa.mathematics.matrices.Matrix;
import de.bioforscher.singa.mathematics.matrices.RegularMatrix;
import de.bioforscher.singa.mathematics.vectors.RegularVector;
import de.bioforscher.singa.mathematics.vectors.Vectors;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/bioforscher/singa/mathematics/algorithms/clustering/AffinityPropagation.class */
public class AffinityPropagation<DataType> implements Clustering<DataType> {
    // Class-wide SLF4J logger.
    private static final Logger logger = LoggerFactory.getLogger(AffinityPropagation.class);
    // Number of most recent exemplar decisions that isConverged() requires to be identical.
    private static final int MIN_STABLE_EPOCHS = 10;
    // The items to cluster; also used as row and column labels of every matrix created here.
    private final List<DataType> dataPoints;
    // Cached dataPoints.size(); dimension of the square message matrices.
    private final int dataSize;
    // Upper bound for the epoch counter checked by run().
    private final int maximalEpochs;
    // The input matrix itself when it was declared a distance matrix, otherwise its additive inverse.
    private final LabeledMatrix<DataType> distanceMatrix;
    // Weight of the previous matrix in applyLambda(); the fresh update is weighted with (1 - lambda).
    private final double lambda;
    // Pairwise similarities; the constructor overwrites the diagonal with the self similarity.
    private LabeledMatrix<DataType> similarityMatrix;
    // Availability messages; zero-initialized by initialize() and replaced in updateAvailabilities().
    private LabeledMatrix<DataType> availabilityMatrix;
    // Responsibility messages; zero-initialized by initialize() and replaced in updateResponsibilities().
    private LabeledMatrix<DataType> responsibilityMatrix;
    // Incremented by run() after every epoch that did not converge.
    private int epoch;
    // One list of chosen exemplars per epoch, appended by assignExemplars().
    private List<List<DataType>> exemplarDecisions;
    // Maps each exemplar to the data points assigned to it; rebuilt by assignClusters() in every epoch.
    private Map<DataType, List<DataType>> clusters;

    /* loaded from: input_file:de/bioforscher/singa/mathematics/algorithms/clustering/AffinityPropagation$Builder.class */
    /**
     * Step builder that collects the input of an {@link AffinityPropagation} run.
     * <p>
     * The step interfaces enforce the order data points, matrix, distance flag, and then any number of optional
     * parameters before {@link #run()} is called. Optional parameters fall back to the defaults declared as
     * constants in this class.
     *
     * @param <DataType> The type of the data points to cluster.
     */
    public static class Builder<DataType> implements DataStep<DataType>, MatrixStep<DataType>, DistanceStep<DataType>, ParameterStep<DataType> {
        private static final double DEFAULT_SELF_SIMILARITY = -0.5d;
        private static final double DEFAULT_LAMBDA = 0.5d;
        private static final int DEFAULT_MAXIMAL_EPOCHS = 1000;
        private List<DataType> dataPoints;
        private LabeledMatrix<DataType> matrix;
        private double selfSimilarity = DEFAULT_SELF_SIMILARITY;
        private double lambda = DEFAULT_LAMBDA;
        private int maximalEpochs = DEFAULT_MAXIMAL_EPOCHS;
        private boolean distance;

        /**
         * Sets the data points to cluster.
         *
         * @param dataPoints The data points; must not be {@code null}.
         * @return The step that accepts the matrix.
         * @throws NullPointerException If {@code dataPoints} is {@code null}.
         */
        @Override
        public MatrixStep<DataType> dataPoints(List<DataType> dataPoints) {
            // fail here instead of deep inside the clustering constructor
            this.dataPoints = Objects.requireNonNull(dataPoints, "dataPoints must not be null");
            return this;
        }

        /**
         * Sets the labeled matrix holding the pairwise values of the data points.
         *
         * @param matrix The similarity or distance matrix; must not be {@code null}.
         * @return The step that declares how the matrix is to be interpreted.
         * @throws NullPointerException If {@code matrix} is {@code null}.
         */
        @Override
        public DistanceStep<DataType> matrix(LabeledMatrix<DataType> matrix) {
            this.matrix = Objects.requireNonNull(matrix, "matrix must not be null");
            return this;
        }

        /**
         * Sets the value used for the diagonal of the similarity matrix.
         *
         * @param selfSimilarity The self similarity (default {@value #DEFAULT_SELF_SIMILARITY}).
         * @return This parameter step.
         */
        @Override
        public ParameterStep<DataType> selfSimilarity(double selfSimilarity) {
            this.selfSimilarity = selfSimilarity;
            return this;
        }

        /**
         * Declares whether the supplied matrix contains distances rather than similarities.
         *
         * @param isDistance {@code true} if the matrix holds distances.
         * @return The step that accepts optional parameters.
         */
        @Override
        public ParameterStep<DataType> isDistance(boolean isDistance) {
            this.distance = isDistance;
            return this;
        }

        /**
         * Sets the weight of the previous messages when blending them with a fresh update.
         *
         * @param lambda The damping weight (default {@value #DEFAULT_LAMBDA}).
         * @return This parameter step.
         */
        @Override
        public ParameterStep<DataType> lambda(double lambda) {
            this.lambda = lambda;
            return this;
        }

        /**
         * Sets the maximal number of epochs before the run is terminated.
         *
         * @param maximalEpochs The epoch limit (default {@value #DEFAULT_MAXIMAL_EPOCHS}).
         * @return This parameter step.
         */
        @Override
        public ParameterStep<DataType> maximalEpochs(int maximalEpochs) {
            this.maximalEpochs = maximalEpochs;
            return this;
        }

        /**
         * Creates the clustering from the collected input; the clustering is computed during construction.
         *
         * @return The finished clustering.
         */
        @Override
        public AffinityPropagation<DataType> run() {
            return new AffinityPropagation<>(this);
        }
    }

    /* loaded from: input_file:de/bioforscher/singa/mathematics/algorithms/clustering/AffinityPropagation$DataStep.class */
    /**
     * First builder step: supplies the data points to cluster.
     *
     * @param <DataType> The type of the data points.
     */
    public interface DataStep<DataType> {
        /**
         * Sets the data points.
         *
         * @param list The data points.
         * @return The step that accepts the matrix.
         */
        MatrixStep<DataType> dataPoints(List<DataType> list);
    }

    /* loaded from: input_file:de/bioforscher/singa/mathematics/algorithms/clustering/AffinityPropagation$DistanceStep.class */
    /**
     * Third builder step: declares how the supplied matrix is to be interpreted.
     *
     * @param <DataType> The type of the data points.
     */
    public interface DistanceStep<DataType> {
        /**
         * Declares whether the matrix holds distances instead of similarities.
         *
         * @param z {@code true} if the matrix holds distances.
         * @return The step that accepts optional parameters.
         */
        ParameterStep<DataType> isDistance(boolean z);
    }

    /* loaded from: input_file:de/bioforscher/singa/mathematics/algorithms/clustering/AffinityPropagation$MatrixStep.class */
    /**
     * Second builder step: supplies the labeled matrix of pairwise values.
     *
     * @param <DataType> The type of the data points.
     */
    public interface MatrixStep<DataType> {
        /**
         * Sets the matrix.
         *
         * @param labeledMatrix The similarity or distance matrix.
         * @return The step that declares how the matrix is to be interpreted.
         */
        DistanceStep<DataType> matrix(LabeledMatrix<DataType> labeledMatrix);
    }

    /* loaded from: input_file:de/bioforscher/singa/mathematics/algorithms/clustering/AffinityPropagation$ParameterStep.class */
    /**
     * Final builder step: optional parameters and the terminal {@link #run()} call.
     *
     * @param <DataType> The type of the data points.
     */
    public interface ParameterStep<DataType> {
        /**
         * Sets the value used for the diagonal of the similarity matrix.
         *
         * @param d The self similarity.
         * @return This step, to allow chaining.
         */
        ParameterStep<DataType> selfSimilarity(double d);

        /**
         * Sets the weight of the previous messages when they are blended with a fresh update.
         *
         * @param d The damping weight.
         * @return This step, to allow chaining.
         */
        ParameterStep<DataType> lambda(double d);

        /**
         * Sets the maximal number of epochs.
         *
         * @param i The epoch limit.
         * @return This step, to allow chaining.
         */
        ParameterStep<DataType> maximalEpochs(int i);

        /**
         * Creates the clustering from the collected input.
         *
         * @return The clustering.
         */
        AffinityPropagation<DataType> run();
    }

    /**
     * Creates the clustering from the builder's input and runs it immediately.
     * <p>
     * If the builder marks the matrix as a distance matrix, it is kept as the distance matrix and its additive
     * inverse becomes the similarity matrix, with the negated self similarity on the diagonal. Otherwise the
     * additive inverse of the matrix becomes the distance matrix and a copy of the matrix, with the self
     * similarity on the diagonal, becomes the similarity matrix.
     *
     * @param builder The builder holding data points, matrix and parameters.
     * @throws NullPointerException If the builder holds no data points or no matrix.
     * @throws IllegalArgumentException If the data points do not equal the row labels of the matrix.
     */
    private AffinityPropagation(Builder<DataType> builder) {
        final List<DataType> points = Objects.requireNonNull(builder.dataPoints, "data points must not be null");
        final LabeledMatrix<DataType> input = Objects.requireNonNull(builder.matrix, "matrix must not be null");
        logger.info("affinity propagation initialized with {} data points", points.size());
        this.dataPoints = points;
        this.dataSize = points.size();
        this.lambda = builder.lambda;
        this.maximalEpochs = builder.maximalEpochs;
        final double selfSimilarity = builder.selfSimilarity;
        checkInput(points, input);

        final double[][] similarities;
        if (builder.distance) {
            this.distanceMatrix = input;
            similarities = ((Matrix) input.additivelyInvert()).getElements();
            for (int i = 0; i < similarities.length; i++) {
                similarities[i][i] = -selfSimilarity;
            }
        } else {
            final double[][] inputElements = input.getElements();
            final LabeledMatrix<DataType> distances = new LabeledRegularMatrix<>(new RegularMatrix(inputElements).additivelyInvert().getElements());
            distances.setRowLabels(points);
            distances.setColumnLabels(points);
            this.distanceMatrix = distances;
            // work on a copy: the array handed out by the input matrix may be its backing storage,
            // and overwriting the diagonal there would silently alter the caller's matrix
            similarities = new double[inputElements.length][];
            for (int row = 0; row < inputElements.length; row++) {
                similarities[row] = Arrays.copyOf(inputElements[row], inputElements[row].length);
            }
            for (int i = 0; i < similarities.length; i++) {
                similarities[i][i] = selfSimilarity;
            }
        }

        final LabeledMatrix<DataType> labeledSimilarities = new LabeledRegularMatrix<>(similarities);
        labeledSimilarities.setRowLabels(points);
        labeledSimilarities.setColumnLabels(points);
        this.similarityMatrix = labeledSimilarities;
        initialize();
        run();
    }

    /**
     * Starts the step-wise configuration of a new affinity propagation run.
     *
     * @param <DataType> The type of the data points.
     * @return The first builder step, which expects the data points.
     */
    public static <DataType> DataStep<DataType> create() {
        final Builder<DataType> builder = new Builder<>();
        return builder;
    }

    /**
     * Returns the data points this clustering was created with.
     *
     * @return The data points.
     */
    @Override
    public List<DataType> getDataPoints() {
        return dataPoints;
    }

    /**
     * Returns the distance matrix determined during construction.
     *
     * @return The distance matrix.
     */
    @Override
    public LabeledMatrix<DataType> getDistanceMatrix() {
        return distanceMatrix;
    }

    /**
     * Returns the cluster assignment computed in the most recent epoch.
     *
     * @return A map from each exemplar to the data points assigned to it.
     */
    @Override
    public Map<DataType, List<DataType>> getClusters() {
        return clusters;
    }

    /**
     * Creates zero-filled responsibility and availability matrices labeled with the data points and resets the
     * history of exemplar decisions.
     */
    private void initialize() {
        final int size = dataSize;

        final LabeledMatrix<DataType> responsibilities = new LabeledRegularMatrix<>(new double[size][size]);
        responsibilities.setRowLabels(dataPoints);
        responsibilities.setColumnLabels(dataPoints);
        responsibilityMatrix = responsibilities;

        final LabeledMatrix<DataType> availabilities = new LabeledRegularMatrix<>(new double[size][size]);
        availabilities.setRowLabels(dataPoints);
        availabilities.setColumnLabels(dataPoints);
        availabilityMatrix = availabilities;

        exemplarDecisions = new ArrayList<>();
    }

    /**
     * Verifies that the data points equal the row labels of the matrix.
     *
     * @param list The data points.
     * @param labeledMatrix The matrix whose row labels are compared.
     * @throws NullPointerException If the matrix has no row labels.
     * @throws IllegalArgumentException If the data points and the row labels are not equal.
     */
    private void checkInput(List<DataType> list, LabeledMatrix<DataType> labeledMatrix) {
        final List<DataType> matrixLabels = Objects.requireNonNull(labeledMatrix.getRowLabels());
        final boolean labelsMatch = list.equals(matrixLabels);
        if (labelsMatch) {
            return;
        }
        throw new IllegalArgumentException("The data does not match the labels of the provided matrix.");
    }

    /**
     * Runs epochs until the exemplar decisions have converged or the epoch limit is reached.
     * <p>
     * Each epoch updates the responsibilities and availabilities, picks the exemplars and assigns every data
     * point to a cluster. The epoch counter is only advanced for epochs that did not converge.
     */
    private void run() {
        // start from an empty result so that an epoch limit of zero (or less) yields no clusters
        // instead of a NullPointerException in the final log statement
        clusters = new HashMap<>();
        while (epoch < maximalEpochs) {
            updateResponsibilities();
            updateAvailabilities();
            assignExemplars();
            assignClusters();
            if (isConverged()) {
                break;
            }
            epoch++;
            if (epoch == maximalEpochs) {
                logger.info("terminating after reaching maximal epoch limit");
            }
        }
        logger.info("obtained {} clusters", clusters.size());
    }

    /**
     * Rebuilds the cluster map from the most recent exemplar decision.
     * <p>
     * A data point that is itself an exemplar is assigned to itself. Otherwise it is assigned to the exemplar
     * with the highest similarity among the exemplars visited. If there are no exemplars, the point is filed
     * under the {@code null} key.
     */
    private void assignClusters() {
        clusters = new HashMap<>();
        for (DataType point : dataPoints) {
            final List<DataType> exemplars = exemplarDecisions.get(exemplarDecisions.size() - 1);
            DataType chosen = null;
            double bestSimilarity = -Double.MAX_VALUE;
            for (DataType exemplar : exemplars) {
                if (exemplar.equals(point)) {
                    // an exemplar always represents itself
                    chosen = point;
                    break;
                }
                final double similarity = similarityMatrix.getValueForLabel(point, exemplar);
                if (similarity > bestSimilarity) {
                    bestSimilarity = similarity;
                    chosen = exemplar;
                }
            }
            List<DataType> members = clusters.get(chosen);
            if (members == null) {
                members = new ArrayList<>();
                clusters.put(chosen, members);
            }
            members.add(point);
        }
    }

    /**
     * Records the exemplars of the current epoch: every data point whose summed responsibility and availability
     * on the diagonal is positive.
     */
    private void assignExemplars() {
        final Matrix summed = (Matrix) responsibilityMatrix.add(availabilityMatrix);
        final LabeledRegularMatrix<DataType> criterion = new LabeledRegularMatrix<>(summed.getElements());
        criterion.setRowLabels(dataPoints);
        criterion.setColumnLabels(dataPoints);

        final List<DataType> exemplars = new ArrayList<>();
        int index = 0;
        while (index < criterion.getRowDimension()) {
            final boolean isExemplar = criterion.getElement(index, index) > 0.0d;
            if (isExemplar) {
                exemplars.add(dataPoints.get(index));
            }
            index++;
        }
        exemplarDecisions.add(exemplars);
    }

    /**
     * Recomputes the responsibility matrix and blends it with the previous one via {@link #applyLambda}.
     * <p>
     * The new responsibility of point {@code i} towards candidate {@code k} is the similarity {@code s(i, k)}
     * minus the largest value of {@code availability + similarity} in row {@code i}, with column {@code k}
     * masked out.
     */
    private void updateResponsibilities() {
        final double[][] updated = new double[dataSize][dataSize];
        final Matrix evidence = (Matrix) similarityMatrix.add(availabilityMatrix);
        for (int point = 0; point < dataPoints.size(); point++) {
            // the row does not depend on the candidate, so it is extracted once per point
            final double[] evidenceRow = evidence.getRow(point).getElements();
            for (int candidate = 0; candidate < dataPoints.size(); candidate++) {
                final double[] competitors = Arrays.copyOf(evidenceRow, evidenceRow.length);
                // exclude the candidate itself from the search for the strongest competitor
                competitors[candidate] = -Double.MAX_VALUE;
                final RegularVector competitorVector = new RegularVector(competitors);
                final double strongest = competitorVector.getElement(Vectors.getIndexWithMaximalElement(competitorVector));
                updated[point][candidate] = similarityMatrix.getElement(point, candidate) - strongest;
            }
        }
        final LabeledMatrix<DataType> undamped = new LabeledRegularMatrix<>(updated);
        responsibilityMatrix = applyLambda(undamped, responsibilityMatrix);
    }

    /**
     * Recomputes the availability matrix and blends it with the previous one via {@link #applyLambda}.
     * <p>
     * For candidate {@code k} and point {@code i}, the positive responsibilities in column {@code k} are summed
     * over all rows except {@code i} and {@code k}. The self availability {@code a(k, k)} is that sum. Every
     * other availability is the sum plus the self responsibility {@code r(k, k)}, capped at zero.
     */
    private void updateAvailabilities() {
        final double[][] updated = new double[dataSize][dataSize];
        for (int candidate = 0; candidate < dataSize; candidate++) {
            // column and self responsibility depend on the candidate only, so look them up once
            final double[] support = responsibilityMatrix.getColumn(candidate).getElements();
            final double selfResponsibility = responsibilityMatrix.getElement(candidate, candidate);
            for (int point = 0; point < dataSize; point++) {
                double positiveSupport = 0.0d;
                for (int other = 0; other < support.length; other++) {
                    if (other != candidate && other != point && support[other] > 0.0d) {
                        positiveSupport += support[other];
                    }
                }
                if (point == candidate) {
                    updated[point][candidate] = positiveSupport;
                } else {
                    final double availability = positiveSupport + selfResponsibility;
                    // only negative values are kept; everything else is capped at zero
                    updated[point][candidate] = availability < 0.0d ? availability : 0.0d;
                }
            }
        }
        final LabeledMatrix<DataType> undamped = new LabeledRegularMatrix<>(updated);
        availabilityMatrix = applyLambda(undamped, availabilityMatrix);
    }

    /**
     * Blends a freshly computed matrix with its predecessor.
     *
     * @param labeledMatrix The new matrix, weighted with {@code 1 - lambda}.
     * @param labeledMatrix2 The previous matrix, weighted with {@code lambda}.
     * @return The weighted sum as a new matrix labeled with the data points.
     */
    private LabeledMatrix<DataType> applyLambda(LabeledMatrix<DataType> labeledMatrix, LabeledMatrix<DataType> labeledMatrix2) {
        final double freshWeight = 1.0d - lambda;
        final Matrix blended = (Matrix) labeledMatrix.multiply(freshWeight).add(labeledMatrix2.multiply(lambda));
        final LabeledMatrix<DataType> damped = new LabeledRegularMatrix<>(blended.getElements());
        damped.setRowLabels(dataPoints);
        damped.setColumnLabels(dataPoints);
        return damped;
    }

    /**
     * Checks whether the exemplar decisions of the last {@code MIN_STABLE_EPOCHS} epochs are all equal.
     *
     * @return {@code true} if the decisions did not change within that window; {@code false} if they changed or
     * fewer decisions have been recorded so far.
     */
    private boolean isConverged() {
        final int decisionCount = exemplarDecisions.size();
        if (decisionCount < MIN_STABLE_EPOCHS) {
            return false;
        }
        final int windowStart = decisionCount - MIN_STABLE_EPOCHS;
        boolean stable = true;
        // compare every decision in the window with its predecessor, newest first
        for (int newer = decisionCount - 1; newer > windowStart; newer--) {
            stable &= exemplarDecisions.get(newer).equals(exemplarDecisions.get(newer - 1));
        }
        logger.debug(stable ? "converged in epoch {}/{}" : "not converged in epoch {}/{}", epoch, maximalEpochs);
        return stable;
    }

    /**
     * Returns the similarity matrix currently used by the clustering.
     *
     * @return The similarity matrix.
     */
    public LabeledMatrix<DataType> getSimilarityMatrix() {
        return similarityMatrix;
    }

    /**
     * Replaces the similarity matrix.
     *
     * @param labeledMatrix The new similarity matrix.
     */
    public void setSimilarityMatrix(LabeledMatrix<DataType> labeledMatrix) {
        similarityMatrix = labeledMatrix;
    }

    /**
     * Returns the availability matrix of the most recent update.
     *
     * @return The availability matrix.
     */
    public LabeledMatrix<DataType> getAvailabilityMatrix() {
        return availabilityMatrix;
    }

    /**
     * Replaces the availability matrix.
     *
     * @param labeledMatrix The new availability matrix.
     */
    public void setAvailabilityMatrix(LabeledMatrix<DataType> labeledMatrix) {
        availabilityMatrix = labeledMatrix;
    }

    /**
     * Returns the responsibility matrix of the most recent update.
     *
     * @return The responsibility matrix.
     */
    public LabeledMatrix<DataType> getResponsibilityMatrix() {
        return responsibilityMatrix;
    }

    /**
     * Replaces the responsibility matrix.
     *
     * @param labeledMatrix The new responsibility matrix.
     */
    public void setResponsibilityMatrix(LabeledMatrix<DataType> labeledMatrix) {
        responsibilityMatrix = labeledMatrix;
    }
}
