package org.deeplearning4j.models.glove;

import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.actor.Props;
import akka.routing.RoundRobinPool;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.CounterMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.glove.actor.CoOccurrenceActor;
import org.deeplearning4j.models.glove.actor.SentenceWork;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.text.movingwindow.Util;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/glove/CoOccurrences.class */
/**
 * Holds word co-occurrence counts, either computed from a corpus via {@link #fit()} or read from
 * a text stream via {@link #load(InputStream)}.
 *
 * <p>Counts are stored in a {@code CounterMap<String, String>} keyed by word pair. During
 * {@link #fit()} the counting itself is delegated to a pool of {@code CoOccurrenceActor}s; this
 * class only dispatches sentences and waits for the actors to report completion.
 *
 * <p>The tokenizer factory, sentence iterator, vocab cache and actor system are {@code transient}
 * and therefore are not part of the serialized form.
 */
public class CoOccurrences implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(CoOccurrences.class);

    /** Fixed pause after dispatching all sentences, before the first completion check. */
    private static final long INITIAL_WAIT_MILLIS = 5000L;

    /** Interval between completion checks while the actors are still working. */
    private static final long POLL_INTERVAL_MILLIS = 10000L;

    private transient TokenizerFactory tokenizerFactory;
    private transient SentenceIterator sentenceIterator;

    /** Window size handed to the co-occurrence actors; defaults to 15. */
    private int windowSize = 15;

    protected transient VocabCache cache;
    protected InvertedIndex index;
    protected transient ActorSystem trainingSystem;

    /** Symmetry flag handed to the co-occurrence actors; defaults to {@code true}. */
    protected boolean symmetric = true;

    /** Per-sentence counter shared with the co-occurrence actors during {@link #fit()}. */
    private Counter<Integer> sentenceOccurrences = Util.parallelCounter();

    /** Co-occurrence counts keyed by (first word, second word). */
    private CounterMap<String, String> coOCurreneCounts = Util.parallelCounterMap();

    /** Pair list cached by {@link #load(InputStream)}; {@code null} for instances built otherwise. */
    private List<Pair<String, String>> coOccurrences;

    /**
     * Fluent builder for {@link CoOccurrences}. A vocab cache and a sentence iterator are
     * mandatory; everything else has a default.
     */
    public static class Builder {
        private SentenceIterator sentenceIterator;
        private VocabCache cache;
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private int windowSize = 15;
        private CounterMap<String, String> coOCurreneCounts = Util.parallelCounterMap();
        private boolean symmetric = true;

        /**
         * Sets the symmetry flag (default {@code true}).
         *
         * @param z the flag value forwarded to the built instance
         * @return this builder
         */
        public Builder symmetric(boolean z) {
            this.symmetric = z;
            return this;
        }

        /**
         * Sets the tokenizer factory (default: a {@link DefaultTokenizerFactory}).
         *
         * @param tokenizerFactory the tokenizer factory to use
         * @return this builder
         */
        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Sets the sentence iterator that supplies the corpus. Required.
         *
         * @param sentenceIterator the corpus iterator
         * @return this builder
         */
        public Builder iterate(SentenceIterator sentenceIterator) {
            this.sentenceIterator = sentenceIterator;
            return this;
        }

        /**
         * Sets the window size (default 15).
         *
         * @param i the window size
         * @return this builder
         */
        public Builder windowSize(int i) {
            this.windowSize = i;
            return this;
        }

        /**
         * Sets the vocab cache. Required.
         *
         * @param vocabCache the vocabulary cache
         * @return this builder
         */
        public Builder cache(VocabCache vocabCache) {
            this.cache = vocabCache;
            return this;
        }

        /**
         * Sets the counter map that will receive the co-occurrence counts
         * (default: {@code Util.parallelCounterMap()}).
         *
         * @param counterMap the counter map to populate
         * @return this builder
         */
        public Builder coOCurreneCounts(CounterMap<String, String> counterMap) {
            this.coOCurreneCounts = counterMap;
            return this;
        }

        /**
         * Builds the configured {@link CoOccurrences}.
         *
         * @return a new instance
         * @throws IllegalArgumentException if the vocab cache or the sentence iterator is missing
         */
        public CoOccurrences build() {
            if (this.cache == null) {
                throw new IllegalArgumentException("Vocab cache must not be null!");
            }
            if (this.sentenceIterator == null) {
                throw new IllegalArgumentException("Sentence iterator must not be null");
            }
            return new CoOccurrences(this.tokenizerFactory, this.sentenceIterator, this.windowSize, this.cache, this.coOCurreneCounts, this.symmetric);
        }
    }

    /**
     * Groups the pairs produced by {@link CoOccurrences#coOccurrenceIteratorVocab()} into lists of
     * at most {@code batchSize} elements. The final batch may be shorter.
     */
    public class CoOccurrenceBatchIterator implements Iterator<List<Pair<VocabWord, VocabWord>>> {
        private final Iterator<Pair<VocabWord, VocabWord>> iter;
        private final int batchSize;

        /**
         * @param i maximum number of pairs per batch; must be positive
         * @throws IllegalArgumentException if {@code i} is not positive (a zero batch size would
         *         make {@link #next()} return empty lists forever while {@link #hasNext()} stays true)
         */
        public CoOccurrenceBatchIterator(int i) {
            if (i <= 0) {
                throw new IllegalArgumentException("Batch size must be positive, got " + i);
            }
            this.iter = CoOccurrences.this.coOccurrenceIteratorVocab();
            this.batchSize = i;
        }

        /**
         * Creates a batch iterator with the default batch size of 100.
         *
         * @param coOccurrences unused; kept for signature compatibility
         */
        public CoOccurrenceBatchIterator(CoOccurrences coOccurrences) {
            this(100);
        }

        @Override
        public boolean hasNext() {
            return this.iter.hasNext();
        }

        /**
         * Returns the next batch. If the underlying iterator is already exhausted an empty list is
         * returned rather than an exception being thrown.
         */
        @Override
        public List<Pair<VocabWord, VocabWord>> next() {
            List<Pair<VocabWord, VocabWord>> batch = new ArrayList<>(this.batchSize);
            while (batch.size() < this.batchSize && this.iter.hasNext()) {
                batch.add(this.iter.next());
            }
            return batch;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Adapts the string-pair iterator of the enclosing instance to {@link VocabWord} pairs by
     * looking each word up in the vocab cache.
     */
    public class CoOccurrenceIterator implements Iterator<Pair<VocabWord, VocabWord>> {
        private final Iterator<Pair<String, String>> iter;

        public CoOccurrenceIterator() {
            this.iter = CoOccurrences.this.coOccurrenceIterator();
        }

        @Override
        public boolean hasNext() {
            return this.iter.hasNext();
        }

        /** Resolves both words of the next pair through {@code cache.wordFor(...)}. */
        @Override
        public Pair<VocabWord, VocabWord> next() {
            Pair<String, String> words = this.iter.next();
            VocabWord first = CoOccurrences.this.cache.wordFor(words.getFirst());
            VocabWord second = CoOccurrences.this.cache.wordFor(words.getSecond());
            return new Pair<>(first, second);
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }

    /** Creates an instance with default settings only; used by {@link #load(InputStream)}. */
    private CoOccurrences() {
    }

    /**
     * Creates an instance ready for {@link #fit()}.
     *
     * @param tokenizerFactory tokenizer factory forwarded to the co-occurrence actors
     * @param sentenceIterator source of sentences to process
     * @param i window size
     * @param vocabCache vocabulary cache used to resolve words
     * @param counterMap counter map that receives the co-occurrence counts
     * @param z symmetry flag
     */
    public CoOccurrences(TokenizerFactory tokenizerFactory, SentenceIterator sentenceIterator, int i, VocabCache vocabCache, CounterMap<String, String> counterMap, boolean z) {
        this.tokenizerFactory = tokenizerFactory;
        this.sentenceIterator = sentenceIterator;
        this.windowSize = i;
        this.cache = vocabCache;
        this.coOCurreneCounts = counterMap;
        this.symmetric = z;
    }

    /**
     * Computes co-occurrence counts for every sentence of the configured iterator.
     *
     * <p>Each sentence is sent as a {@code SentenceWork} message to a round-robin pool of
     * {@code CoOccurrenceActor}s (one per available processor). The method then blocks, polling a
     * shared counter until it reaches the number of dispatched sentences. The actor system is
     * always shut down before returning, including when dispatching fails.
     *
     * <p>If the calling thread is interrupted while waiting, the interrupt flag is restored, the
     * wait is abandoned and the actor system is shut down; counts may then be incomplete.
     *
     * @throws IllegalStateException if no sentence iterator is available, e.g. on an instance
     *         obtained from {@link #load(InputStream)}
     */
    public void fit() {
        if (this.sentenceIterator == null) {
            throw new IllegalStateException("Sentence iterator must not be null; this instance cannot be fitted");
        }
        if (this.trainingSystem == null) {
            this.trainingSystem = ActorSystem.create();
        }
        // Shared with the actors; presumably incremented once per processed sentence — the actor
        // implementation is not visible here.
        AtomicInteger processedSentences = new AtomicInteger(0);
        try {
            ActorRef router = this.trainingSystem.actorOf(new RoundRobinPool(Runtime.getRuntime().availableProcessors()).props(Props.create(CoOccurrenceActor.class, new Object[]{processedSentences, this.tokenizerFactory, Integer.valueOf(this.windowSize), this.cache, this.coOCurreneCounts, Boolean.valueOf(this.symmetric), this.sentenceOccurrences})));
            this.sentenceIterator.reset();
            int submittedSentences = 0;
            while (this.sentenceIterator.hasNext()) {
                // The running index doubles as the sentence id.
                router.tell(new SentenceWork(submittedSentences, this.sentenceIterator.nextSentence()), router);
                submittedSentences++;
            }
            awaitCompletion(processedSentences, submittedSentences);
        } finally {
            // Never leak the actor system, even if the iterator or the dispatch throws.
            this.trainingSystem.shutdown();
            this.trainingSystem = null;
        }
        log.info("Done processing co occurrences: ended with {}", numCoOccurrences());
    }

    /**
     * Blocks until {@code processed} reaches {@code expected} or the thread is interrupted.
     *
     * @return {@code true} if all work completed, {@code false} if the wait was interrupted
     */
    private static boolean awaitCompletion(AtomicInteger processed, int expected) {
        try {
            Thread.sleep(INITIAL_WAIT_MILLIS);
            while (processed.get() < expected) {
                Thread.sleep(POLL_INTERVAL_MILLIS);
            }
            return true;
        } catch (InterruptedException e) {
            // Restore the flag so callers further up the stack can observe the interruption.
            Thread.currentThread().interrupt();
            log.warn("Interrupted while waiting for co occurrence processing; processed {} of {} sentences", processed.get(), expected);
            return false;
        }
    }

    /**
     * @param i maximum number of pairs per batch; must be positive
     * @return an iterator over batches of {@link VocabWord} pairs
     */
    public Iterator<List<Pair<VocabWord, VocabWord>>> coOccurrenceIteratorVocabBatch(int i) {
        return new CoOccurrenceBatchIterator(i);
    }

    /** @return an iterator over all co-occurring pairs, resolved to {@link VocabWord}s */
    public Iterator<Pair<VocabWord, VocabWord>> coOccurrenceIteratorVocab() {
        return new CoOccurrenceIterator();
    }

    /**
     * Reads co-occurrence counts from a text stream.
     *
     * <p>Each line is split on single spaces and must contain at least three fields:
     * {@code first second count}. Lines with fewer fields, or with an empty first or second
     * field, are skipped. The input is decoded as UTF-8 so that results do not depend on the
     * platform default charset. The stream is not closed by this method.
     *
     * <p>The returned instance has no tokenizer, sentence iterator or vocab cache.
     *
     * @param inputStream the stream to read from
     * @return a new instance holding the loaded pairs and counts
     * @throws NumberFormatException if the count field of an accepted line is not a valid double
     */
    public static CoOccurrences load(InputStream inputStream) {
        CoOccurrences loaded = new CoOccurrences();
        loaded.coOccurrences = new ArrayList<>();
        CounterMap<String, String> counts = new CounterMap<>();
        LineIterator lines = IOUtils.lineIterator(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
        while (lines.hasNext()) {
            String[] fields = lines.nextLine().split(" ");
            if (fields.length < 3 || fields[0].isEmpty() || fields[1].isEmpty()) {
                continue;
            }
            loaded.coOccurrences.add(new Pair<>(fields[0], fields[1]));
            counts.incrementCount(fields[0], fields[1], Double.parseDouble(fields[2]));
        }
        loaded.coOCurreneCounts = counts;
        return loaded;
    }

    /** @return the per-sentence counter shared with the co-occurrence actors */
    public Counter<Integer> getSentenceOccurrences() {
        return this.sentenceOccurrences;
    }

    /** @param counter the per-sentence counter to use */
    public void setSentenceOccurrences(Counter<Integer> counter) {
        this.sentenceOccurrences = counter;
    }

    /**
     * Returns all co-occurring word pairs.
     *
     * <p>For instances created by {@link #load(InputStream)} this is the cached list itself (not a
     * copy); otherwise a fresh list is built from {@link #coOccurrenceIterator()} on every call.
     *
     * @return the list of word pairs
     */
    public List<Pair<String, String>> coOccurrenceList() {
        if (this.coOccurrences != null) {
            return this.coOccurrences;
        }
        List<Pair<String, String>> pairs = new ArrayList<>();
        Iterator<Pair<String, String>> it = coOccurrenceIterator();
        while (it.hasNext()) {
            pairs.add(it.next());
        }
        return pairs;
    }

    /**
     * Returns the co-occurring word pairs in random order.
     *
     * <p>The shuffle is applied to a copy, so the list cached by {@link #load(InputStream)} keeps
     * its original order.
     *
     * @return a new, shuffled list of word pairs
     */
    public List<Pair<String, String>> randomizedList() {
        List<Pair<String, String>> shuffled = new ArrayList<>(coOccurrenceList());
        Collections.shuffle(shuffled);
        return shuffled;
    }

    /** @return {@code totalSize()} of the underlying counter map */
    public int numCoOccurrences() {
        return this.coOCurreneCounts.totalSize();
    }

    /**
     * @param str first word
     * @param str2 second word
     * @return the count stored for the pair in the underlying counter map
     */
    public double count(String str, String str2) {
        return this.coOCurreneCounts.getCount(str, str2);
    }

    /** @return the pair iterator of the underlying counter map */
    public Iterator<Pair<String, String>> coOccurrenceIterator() {
        return this.coOCurreneCounts.getPairIterator();
    }

    /** @return the underlying counter map (not a copy) */
    public CounterMap<String, String> getCoOCurreneCounts() {
        return this.coOCurreneCounts;
    }

    /** @param counterMap the counter map to use from now on */
    public void setCoOCurreneCounts(CounterMap<String, String> counterMap) {
        this.coOCurreneCounts = counterMap;
    }
}
