package de.jungblut.ner;

import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.SingleEntryDoubleVector;
import de.jungblut.math.sparse.SparseDoubleVector;
import de.jungblut.math.tuple.Tuple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:de/jungblut/ner/SparseFeatureExtractorHelper.class */
/**
 * Turns token sequences into sparse binary feature vectors, and their labels into outcome
 * vectors, using a {@link SequenceFeatureExtractor} and a sorted string dictionary.
 *
 * <p>Every string feature produced by the extractor is looked up in the dictionary with
 * {@link Arrays#binarySearch(Object[], Object)}; a hit sets that dictionary index to {@code 1.0}
 * and unknown features are ignored. If no dictionary is passed to the constructor, one is built
 * (distinct features, sorted) from the features of the first call to {@link #vectorize()},
 * {@link #vectorize(List, List)} or {@link #vectorizeAdditionals(List, List)}.
 *
 * <p>This class performs no synchronization: the dictionary is initialized lazily and
 * {@link #getDictionary()} exposes the internal array.
 *
 * @param <K> the token type handled by the feature extractor.
 */
public final class SparseFeatureExtractorHelper<K> {
    private final List<K> words;
    private final List<Integer> labels;
    private final SequenceFeatureExtractor<K> extractor;
    /** Distinct values of the label list given to the constructor. */
    private final HashSet<Integer> classSet;
    /**
     * Number of distinct constructor labels; used as the outcome-vector dimension.
     * NOTE(review): one-hot outcomes are indexed by the label value itself, so this assumes the
     * labels form the contiguous range 0..classes-1 — confirm against callers.
     */
    private final int classes;
    /** Sorted feature dictionary; {@code null} until supplied or built by the first extraction. */
    private String[] dicts;

    /**
     * Creates a helper that builds its feature dictionary on the first extraction.
     *
     * @param words the token sequence vectorized by {@link #vectorize()}.
     * @param labels the label of each token; its distinct values define the label classes.
     * @param extractor computes the string features of a token at a given position.
     */
    public SparseFeatureExtractorHelper(List<K> words, List<Integer> labels,
            SequenceFeatureExtractor<K> extractor) {
        this.words = words;
        this.labels = labels;
        this.extractor = extractor;
        this.classSet = new HashSet<>(labels);
        this.classes = this.classSet.size();
    }

    /**
     * Creates a helper that uses a pre-built feature dictionary, for example one obtained from
     * {@link #getDictionary()} of another helper.
     *
     * @param words the token sequence vectorized by {@link #vectorize()}.
     * @param labels the label of each token; its distinct values define the label classes.
     * @param extractor computes the string features of a token at a given position.
     * @param dictionary the feature dictionary. It must be sorted in natural {@link String}
     *     order because lookups use binary search. The array is stored as-is, not copied.
     */
    public SparseFeatureExtractorHelper(List<K> words, List<Integer> labels,
            SequenceFeatureExtractor<K> extractor, String[] dictionary) {
        this(words, labels, extractor);
        this.dicts = dictionary;
    }

    /**
     * Vectorizes the tokens and labels given at construction time, building the dictionary from
     * their features if none exists yet.
     *
     * @return a tuple of (feature vectors, outcome vectors), one entry per token.
     */
    public Tuple<DoubleVector[], DoubleVector[]> vectorize() {
        return extractInternal(this.words, this.labels);
    }

    /**
     * Vectorizes a single token as a one-element sequence, using {@code 0} as previous label.
     *
     * @param word the token to vectorize.
     * @return a binary sparse vector over the current dictionary.
     * @throws NullPointerException if no dictionary has been supplied or built yet.
     */
    public DoubleVector vectorize(K word) {
        // The (Integer) cast selects the single-token overload below.
        return vectorize(word, (Integer) null);
    }

    /**
     * Vectorizes a single token as a one-element sequence.
     *
     * @param word the token to vectorize.
     * @param previousLabel passed to the extractor in the argument position that
     *     {@code extractInternal} uses for the preceding token's label; {@code null} means 0.
     * @return a binary sparse vector over the current dictionary.
     * @throws NullPointerException if no dictionary has been supplied or built yet.
     */
    // Arrays.asList on one generic element creates a K[] varargs array that only backs the list.
    @SuppressWarnings("unchecked")
    public DoubleVector vectorize(K word, Integer previousLabel) {
        List<String> features = this.extractor.computeFeatures(Arrays.asList(word),
                previousLabel == null ? 0 : previousLabel.intValue(), 0);
        return toFeatureVector(requireDictionary(), features);
    }

    /**
     * Vectorizes the given tokens and labels; same as
     * {@link #vectorizeAdditionals(List, List)}.
     *
     * @param words the tokens to vectorize.
     * @param labels the label of each token.
     * @return a tuple of (feature vectors, outcome vectors), one entry per token.
     */
    public Tuple<DoubleVector[], DoubleVector[]> vectorize(List<K> words, List<Integer> labels) {
        return vectorizeAdditionals(words, labels);
    }

    /**
     * Vectorizes additional tokens and labels. Uses the existing dictionary if there is one,
     * otherwise builds it from these tokens' features.
     *
     * @param words the tokens to vectorize.
     * @param labels the label of each token; index {@code i - 1} is used as the previous label
     *     for token {@code i}, and index {@code i} is encoded as token {@code i}'s outcome.
     * @return a tuple of (feature vectors, outcome vectors), one entry per token.
     */
    public Tuple<DoubleVector[], DoubleVector[]> vectorizeAdditionals(List<K> words,
            List<Integer> labels) {
        return extractInternal(words, labels);
    }

    /**
     * Vectorizes every token once per candidate previous label. The first token is vectorized
     * only with previous label 0; every later token is vectorized once for each distinct
     * constructor label, in {@code classSet} iteration order.
     *
     * @param words the tokens to vectorize.
     * @return {@code 1 + (words.size() - 1) * classSet.size()} vectors for non-empty input,
     *     an empty array for empty input.
     * @throws NullPointerException if no dictionary has been supplied or built yet.
     */
    public DoubleVector[] vectorizeEachLabel(List<K> words) {
        List<List<String>> featureLists = new ArrayList<>();
        for (int position = 0; position < words.size(); position++) {
            if (position == 0) {
                featureLists.add(this.extractor.computeFeatures(words, 0, position));
            } else {
                for (Integer label : this.classSet) {
                    featureLists.add(
                            this.extractor.computeFeatures(words, label.intValue(), position));
                }
            }
        }
        return toFeatureVectors(featureLists);
    }

    /**
     * Returns the sorted feature dictionary, or {@code null} if none has been supplied or built
     * yet. The internal array is returned without copying; modifying it affects this helper.
     *
     * @return the dictionary array, possibly {@code null}.
     */
    public String[] getDictionary() {
        return this.dicts;
    }

    /** Extracts features for every token, lazily builds the dictionary, and encodes outcomes. */
    private Tuple<DoubleVector[], DoubleVector[]> extractInternal(List<K> words,
            List<Integer> labels) {
        List<List<String>> featureLists = new ArrayList<>(words.size());
        for (int position = 0; position < words.size(); position++) {
            // The first token has no predecessor, so label 0 stands in for it.
            int previousLabel = position == 0 ? 0 : labels.get(position - 1).intValue();
            featureLists.add(this.extractor.computeFeatures(words, previousLabel, position));
        }
        if (this.dicts == null) {
            this.dicts = buildDictionary(featureLists);
        }
        DoubleVector[] features = toFeatureVectors(featureLists);
        DoubleVector[] outcomes = new DoubleVector[featureLists.size()];
        for (int i = 0; i < outcomes.length; i++) {
            outcomes[i] = toOutcomeVector(labels.get(i).intValue());
        }
        return new Tuple<>(features, outcomes);
    }

    /** Collects the distinct features of all lists into a naturally sorted array. */
    private static String[] buildDictionary(List<List<String>> featureLists) {
        HashSet<String> distinct = new HashSet<>();
        for (List<String> features : featureLists) {
            distinct.addAll(features);
        }
        String[] dictionary = distinct.toArray(new String[distinct.size()]);
        // Sorting is required: lookups use binary search.
        Arrays.sort(dictionary);
        return dictionary;
    }

    /** Converts each feature list into a binary feature vector over the current dictionary. */
    private DoubleVector[] toFeatureVectors(List<List<String>> featureLists) {
        String[] dictionary = requireDictionary();
        DoubleVector[] vectors = new DoubleVector[featureLists.size()];
        for (int i = 0; i < vectors.length; i++) {
            vectors[i] = toFeatureVector(dictionary, featureLists.get(i));
        }
        return vectors;
    }

    /** Builds a sparse vector with 1.0 at the dictionary index of every known feature. */
    private static DoubleVector toFeatureVector(String[] dictionary, List<String> features) {
        DoubleVector vector = new SparseDoubleVector(dictionary.length);
        for (String feature : features) {
            int index = Arrays.binarySearch(dictionary, feature);
            // A negative result means the feature is not in the dictionary; skip it.
            if (index >= 0) {
                vector.set(index, 1.0d);
            }
        }
        return vector;
    }

    /**
     * Encodes a label as an outcome vector. With exactly two classes the label value is wrapped
     * in a {@link SingleEntryDoubleVector}; otherwise a one-hot vector of size {@code classes}.
     */
    private DoubleVector toOutcomeVector(int label) {
        if (this.classes == 2) {
            return new SingleEntryDoubleVector(label);
        }
        DoubleVector outcome = new SparseDoubleVector(this.classes);
        outcome.set(label, 1.0d);
        return outcome;
    }

    /** Returns the dictionary, failing with a descriptive NPE if it is not available yet. */
    private String[] requireDictionary() {
        return Objects.requireNonNull(this.dicts,
                "no feature dictionary: pass one to the constructor or call vectorize() first");
    }
}
