package de.jungblut.ner;

import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.distance.CosineDistance;
import de.jungblut.distance.DistanceMeasurer;
import de.jungblut.distance.SimilarityMeasurer;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.tuple.Tuple;
import gnu.trove.list.array.TIntArrayList;
import java.util.Arrays;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/* loaded from: input_file:de/jungblut/ner/IterativeSimilarityAggregation.class */
/**
 * Expands a small set of seed tokens into a larger set of related terms by iterative similarity
 * aggregation.
 *
 * <p>Every term of a term list is paired with a feature vector from a weight matrix. Terms are
 * scored by their mean similarity to the seeds and to the currently selected set. The selection
 * is recomputed until it stops changing or an iteration cap is hit.
 *
 * <p>NOTE(review): the scoring scheme appears to follow the static-thresholding variant of SEISA
 * (He &amp; Xin, "Set Expansion by Iterative Similarity Aggregation", WWW 2011); confirm against
 * the paper before tuning it further.
 */
public final class IterativeSimilarityAggregation {
    private static final Logger LOG = LogManager.getLogger(IterativeSimilarityAggregation.class);

    /** Weight of the seed relevance in the combined score; the set similarity gets 1 - alpha. */
    private final double alpha;
    private final SimilarityMeasurer similarityMeasurer;
    private final String[] seedTokens;
    /** Indices into {@link #termNodes} of the seed tokens that were found in the term list. */
    private final int[] seedIndices;
    private final String[] termNodes;
    /** Transposed input matrix: column i is the feature vector of {@code termNodes[i]}. */
    private final DoubleMatrix weightMatrix;

    /**
     * Creates an instance with {@code alpha = 0.5} and {@link CosineDistance} as the distance.
     *
     * @param strArr the seed tokens to expand.
     * @param tuple the term list (first) and the matching weight matrix (second).
     */
    public IterativeSimilarityAggregation(String[] strArr, Tuple<String[], DoubleMatrix> tuple) {
        this(strArr, tuple, 0.5d, new CosineDistance());
    }

    /**
     * Creates an instance.
     *
     * @param seedTokens the seed tokens to expand; tokens missing from the term list are logged and
     *     ignored.
     * @param termsAndMatrix the term list (first) and the weight matrix (second). The matrix is
     *     transposed here and column i is then used as the vector of term i, so presumably the
     *     input has one row per term (TODO confirm with callers).
     * @param alpha weight of the seed relevance in the combined score; NOTE(review): expected to be
     *     in [0, 1], but this is not validated.
     * @param distanceMeasurer the distance wrapped into a {@link SimilarityMeasurer}.
     * @throws NullPointerException if {@code seedTokens} or {@code termsAndMatrix} is null.
     */
    public IterativeSimilarityAggregation(String[] seedTokens,
            Tuple<String[], DoubleMatrix> termsAndMatrix, double alpha,
            DistanceMeasurer distanceMeasurer) {
        this.seedTokens = Objects.requireNonNull(seedTokens, "seedTokens");
        Objects.requireNonNull(termsAndMatrix, "termsAndMatrix");
        this.termNodes = (String[]) termsAndMatrix.getFirst();
        // Transposed so that computeRelevanceScore can fetch a term's vector as a column.
        this.weightMatrix = ((DoubleMatrix) termsAndMatrix.getSecond()).transpose();
        this.alpha = alpha;
        this.similarityMeasurer = new SimilarityMeasurer(distanceMeasurer);
        this.seedIndices = findSeedIndices();
    }

    /** Resolves each seed token to its index in the term list, logging and skipping unknown ones. */
    private int[] findSeedIndices() {
        TIntArrayList indices = new TIntArrayList();
        for (String seed : seedTokens) {
            int index = ArrayUtils.find(termNodes, seed);
            if (index < 0) {
                LOG.info("Seed token \"{}\" could not be found in the term list!", seed);
                continue;
            }
            indices.add(index);
        }
        return indices.toArray();
    }

    /**
     * Expands the seed set by iterative static thresholding.
     *
     * <p>Every term is first scored by its mean similarity to the seeds, and the terms scoring
     * above 0 form the initial selection. Each pass then re-scores all terms as
     * {@code alpha * seedRelevance + (1 - alpha) * meanSimilarityToSelection}. It keeps the terms
     * scoring strictly above {@code threshold}, ordered by descending score.
     *
     * <p>The loop stops when a pass yields exactly the same items, in the same order, as the
     * previous selection. If {@code maxIterations > 0}, it also stops once the pass counter
     * (starting at 0) exceeds {@code maxIterations}, i.e. after at most
     * {@code maxIterations + 2} passes. NOTE(review): with {@code maxIterations <= 0} nothing
     * bounds the loop if the selection never stabilizes.
     *
     * @param threshold exclusive lower bound on the combined score for a term to be kept.
     * @param maxIterations iteration cap; values {@code <= 0} disable the cap.
     * @param verbose whether to log the size of the selection after every pass.
     * @return the selected terms of the last pass, ordered by descending combined score; empty if
     *     no seed token was found in the term list.
     */
    public String[] startStaticThresholding(double threshold, int maxIterations, boolean verbose) {
        DenseDoubleVector seedRelevance = computeRelevanceScore(seedIndices);
        int[] selection = filterRelevantItems(seedRelevance, 0.0d);
        int iteration = 0;
        while (true) {
            DenseDoubleVector selectionSimilarity = computeRelevanceScore(selection);
            int[] topRankedItems = getTopRankedItems(
                    rankScores(alpha, seedRelevance, selectionSimilarity), threshold);
            // Converged when neither the members nor their order changed.
            boolean converged = Arrays.equals(topRankedItems, selection);
            if (verbose) {
                LOG.info("{} | Top ranked item size: {}", iteration, topRankedItems.length);
            }
            selection = topRankedItems;
            if (converged || (maxIterations > 0 && iteration > maxIterations)) {
                break;
            }
            iteration++;
        }
        String[] result = new String[selection.length];
        for (int k = 0; k < selection.length; k++) {
            result[k] = termNodes[selection[k]];
        }
        return result;
    }

    /**
     * Returns the indices whose score is strictly greater than {@code threshold}, ordered by
     * descending score.
     *
     * <p>Sorts a deep copy of {@code scores} with a selection sort, so the argument itself is not
     * modified. Ties keep the order that selection sort leaves them in, which is not necessarily
     * ascending index order.
     *
     * @param scores one score per index.
     * @param threshold exclusive lower bound for an index to be returned.
     * @return the qualifying indices, best first.
     */
    static int[] getTopRankedItems(DoubleVector scores, double threshold) {
        DoubleVector sorted = scores.deepCopy();
        int length = sorted.getLength();
        int[] order = new int[length];
        for (int k = 0; k < length; k++) {
            order[k] = k;
        }
        for (int pos = 0; pos < length - 1; pos++) {
            int best = pos;
            for (int candidate = pos + 1; candidate < length; candidate++) {
                if (sorted.get(candidate) > sorted.get(best)) {
                    best = candidate;
                }
            }
            if (best != pos) {
                double tmp = sorted.get(best);
                sorted.set(best, sorted.get(pos));
                sorted.set(pos, tmp);
                // Keep the index permutation in lock-step with the sorted values.
                ArrayUtils.swap(order, best, pos);
            }
        }
        TIntArrayList result = new TIntArrayList();
        for (int index : order) {
            // Descending order: the first score at or below the threshold ends the run.
            if (scores.get(index) <= threshold) {
                break;
            }
            result.add(index);
        }
        return result.toArray();
    }

    /**
     * Scores every term by its mean similarity to the given items.
     *
     * <p>Pairs where either column vector is {@code null} contribute 0 but still count towards the
     * mean.
     *
     * @param items indices into the term list (columns of the weight matrix) to compare against.
     * @return a vector with one entry per term: the mean similarity of that term to {@code items},
     *     or 0 for every term when {@code items} is empty.
     */
    private DenseDoubleVector computeRelevanceScore(int[] items) {
        int termCount = termNodes.length;
        DenseDoubleVector scores = new DenseDoubleVector(termCount);
        if (items.length == 0) {
            // Without this guard the mean would be Infinity * 0.0 = NaN for every term. NaN
            // scores slip through getTopRankedItems' "<= threshold" cut-off, so every term
            // would be selected.
            for (int term = 0; term < termCount; term++) {
                scores.set(term, 0.0d);
            }
            return scores;
        }
        double scale = 1.0d / items.length;
        for (int term = 0; term < termCount; term++) {
            // Loop-invariant for the inner loop, so fetch it once per term.
            DoubleVector termColumn = weightMatrix.getColumnVector(term);
            double sum = 0.0d;
            if (termColumn != null) {
                for (int item : items) {
                    DoubleVector itemColumn = weightMatrix.getColumnVector(item);
                    if (itemColumn != null) {
                        sum += similarityMeasurer.measureSimilarity(termColumn, itemColumn);
                    }
                }
            }
            scores.set(term, scale * sum);
        }
        return scores;
    }

    /**
     * Combines the two score vectors as
     * {@code alpha * seedRelevance + (1 - alpha) * selectionSimilarity}.
     *
     * <p>The weights sum to 1 so the combined score stays on the scale of the similarities and
     * remains comparable with the absolute threshold in {@link #getTopRankedItems}. For the
     * default {@code alpha = 0.5} this is the plain average of the two vectors.
     *
     * @param alpha weight of {@code seedRelevance}.
     * @param seedRelevance mean similarity of each term to the seeds.
     * @param selectionSimilarity mean similarity of each term to the current selection.
     * @return the combined score per term.
     */
    static DoubleVector rankScores(double alpha, DenseDoubleVector seedRelevance,
            DenseDoubleVector selectionSimilarity) {
        return selectionSimilarity.multiply(1.0d - alpha).add(seedRelevance.multiply(alpha));
    }

    /**
     * Returns, in ascending index order, the indices whose score is strictly greater than
     * {@code threshold}.
     *
     * @param scores one score per index.
     * @param threshold exclusive lower bound.
     * @return the qualifying indices.
     */
    static int[] filterRelevantItems(DenseDoubleVector scores, double threshold) {
        TIntArrayList relevant = new TIntArrayList();
        for (int index = 0; index < scores.getLength(); index++) {
            if (scores.get(index) > threshold) {
                relevant.add(index);
            }
        }
        return relevant.toArray();
    }
}
