package org.deeplearning4j.models.word2vec.iterator;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.datasets.iterator.DataSetFetcher;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.movingwindow.Window;
import org.deeplearning4j.text.movingwindow.WindowConverter;
import org.deeplearning4j.text.movingwindow.Windows;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/word2vec/iterator/Word2VecDataFetcher.class */
/**
 * Builds {@link DataSet}s from text files found under a root path.
 *
 * <p>Each line of a file is split into {@link Window}s via {@link Windows#windows(String)}.
 * A window becomes one feature row via
 * {@link WindowConverter#asExampleMatrix(Window, Word2Vec)} and one outcome row via
 * {@link FeatureUtil#toOutcomeVector(int, int)}, indexed by the window's label position in
 * the configured label list. A label that is not in the list falls back to index 0.
 *
 * <p>Usage contract visible in this class: {@link #reset()} must be called before
 * {@link #next()} or {@link #hasMore()} (it creates the file iterator), and
 * {@link #fetch(int)} must be called to set the batch size.
 */
public class Word2VecDataFetcher implements DataSetFetcher {
    private static final long serialVersionUID = 3245955804749769475L;
    /** File iterator created by {@link #reset()}; {@code null} until then and after deserialization. */
    private transient Iterator<File> files;
    private Word2Vec vec;
    private List<String> labels;
    /** Maximum number of rows per returned data set; set through {@link #fetch(int)}. */
    private int batch;
    /** Windows that did not fit into a previous batch and are waiting to be returned. */
    private List<Window> cache = new ArrayList<>();
    // NOTE(review): never assigned anywhere in this class, so totalExamples() always returns 0.
    private int totalExamples;
    private String path;
    private static final Pattern begin = Pattern.compile("<[A-Z]+>");
    private static final Pattern end = Pattern.compile("</[A-Z]+>");
    private static final Logger log = LoggerFactory.getLogger(Word2VecDataFetcher.class);

    /**
     * @param str      root directory that {@link #reset()} scans recursively for files
     * @param word2Vec model used to turn windows into feature rows; must not be {@code null}
     * @param list     label names; must not be {@code null} or empty
     * @throws IllegalArgumentException if the model is {@code null} or the label list is
     *                                  {@code null} or empty
     */
    public Word2VecDataFetcher(String str, Word2Vec word2Vec, List<String> list) {
        if (word2Vec == null || list == null || list.isEmpty()) {
            throw new IllegalArgumentException("Unable to initialize due to missing argument or empty label applyTransformToDestination");
        }
        this.vec = word2Vec;
        this.labels = list;
        this.path = str;
    }

    /**
     * Converts the first {@code rows} windows of {@code windows} into a data set whose
     * feature and outcome matrices both have exactly {@code rows} rows.
     */
    private DataSet toDataSet(List<Window> windows, int rows) {
        int outcomeCount = this.labels.size();
        INDArray features = Nd4j.create(rows, inputColumns());
        INDArray outcomes = Nd4j.create(rows, outcomeCount);
        for (int row = 0; row < rows; row++) {
            Window window = windows.get(row);
            features.putRow(row, WindowConverter.asExampleMatrix(window, this.vec));
            // Unknown labels are mapped to the first label rather than rejected.
            int labelIndex = Math.max(this.labels.indexOf(window.getLabel()), 0);
            outcomes.putRow(row, FeatureUtil.toOutcomeVector(labelIndex, outcomeCount));
        }
        return new DataSet(features, outcomes);
    }

    /**
     * Pops up to {@code batch} windows from the cache and converts them to a data set.
     *
     * @return the data set, or {@code null} when there is nothing to return
     */
    private DataSet fromCache() {
        // The cache may hold fewer than a full batch once all files are consumed.
        int rows = Math.min(this.batch, this.cache.size());
        if (rows <= 0) {
            return null;
        }
        List<Window> head = this.cache.subList(0, rows);
        DataSet dataSet = toDataSet(head, rows);
        // Remove what was just returned so it is not served again and hasMore() can become false.
        head.clear();
        return dataSet;
    }

    /**
     * Returns the next data set.
     *
     * <p>If the cache already holds a full batch, or no files remain, the data comes from
     * the cache. Otherwise the next file is opened and the first line that yields at least
     * one window is converted; windows beyond the batch size are cached for later calls.
     *
     * @return the next data set, or {@code null} when the file produced no windows or the
     *         cache is empty and no files remain
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    public DataSet next() {
        if (this.cache.size() >= this.batch || !this.files.hasNext()) {
            return fromCache();
        }
        try {
            LineIterator lines = FileUtils.lineIterator(this.files.next());
            try {
                // NOTE(review): only the first line that yields windows is used; the remaining
                // lines of the file are dropped, as in the original code. Confirm this is intended.
                while (lines.hasNext()) {
                    List<Window> windows = Windows.windows(lines.nextLine());
                    if (windows.isEmpty()) {
                        continue;
                    }
                    int rows = Math.min(windows.size(), this.batch);
                    DataSet dataSet = toDataSet(windows, rows);
                    if (windows.size() > rows) {
                        this.cache.addAll(windows.subList(rows, windows.size()));
                    }
                    return dataSet;
                }
                return null;
            } finally {
                // Release the underlying reader on every exit path.
                lines.close();
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** @return the stored example count (never assigned in this class, so 0). */
    public int totalExamples() {
        return this.totalExamples;
    }

    /** @return the feature row width: the lookup table's layer size times the model's window size. */
    public int inputColumns() {
        return this.vec.lookupTable().layerSize() * this.vec.getWindow();
    }

    /** @return the number of configured labels. */
    public int totalOutcomes() {
        return this.labels.size();
    }

    /** Restarts iteration: re-scans the root path recursively for files and empties the cache. */
    public void reset() {
        this.files = FileUtils.iterateFiles(new File(this.path), (String[]) null, true);
        this.cache.clear();
    }

    /** @return always 0; this fetcher does not track a position. */
    public int cursor() {
        return 0;
    }

    /** @return {@code true} while unread files remain or cached windows are pending. */
    public boolean hasMore() {
        return this.files.hasNext() || !this.cache.isEmpty();
    }

    /**
     * Sets the batch size used by subsequent {@link #next()} calls; no data is loaded here.
     *
     * @param i maximum number of rows per data set
     */
    public void fetch(int i) {
        this.batch = i;
    }

    public Iterator<File> getFiles() {
        return this.files;
    }

    public Word2Vec getVec() {
        return this.vec;
    }

    public static Pattern getBegin() {
        return begin;
    }

    public static Pattern getEnd() {
        return end;
    }

    public List<String> getLabels() {
        return this.labels;
    }

    public int getBatch() {
        return this.batch;
    }

    /** @return the live internal cache list (not a copy). */
    public List<Window> getCache() {
        return this.cache;
    }
}
