package org.deeplearning4j.spark.models.paragraphvectors;

import java.util.Objects;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.paragraphvectors.functions.DocumentSequenceConvertFunction;
import org.deeplearning4j.spark.models.paragraphvectors.functions.KeySequenceConvertFunction;
import org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors;
import org.deeplearning4j.text.documentiterator.LabelledDocument;

/**
 * Spark-distributed ParagraphVectors training, built on {@link SparkSequenceVectors} over
 * {@link VocabWord} elements.
 *
 * <p>Input documents are converted to sequences with a Spark {@code map} and then passed to
 * {@code fitSequences} of the superclass. A {@code TokenizerFactory} must be configured, otherwise
 * {@link #validateConfiguration()} rejects the configuration.
 */
public class SparkParagraphVectors extends SparkSequenceVectors<VocabWord> {

    /**
     * Creates an unconfigured instance.
     *
     * <p>NOTE(review): protected, so instances are presumably produced by a builder or subclass
     * defined elsewhere — confirm against the rest of the project.
     */
    protected SparkParagraphVectors() {
    }

    /**
     * Returns the shallow vocabulary cache by delegating to
     * {@link SparkSequenceVectors#getShallowVocabCache()}.
     *
     * <p>NOTE(review): decompiler metadata indicates this override was originally
     * {@code protected}; it is kept {@code public} so existing callers continue to compile.
     *
     * @return the superclass's shallow vocabulary cache
     */
    @Override
    public VocabCache<ShallowSequenceElement> getShallowVocabCache() {
        return super.getShallowVocabCache();
    }

    /**
     * Validates the configuration for ParagraphVectors training.
     *
     * <p>Runs the superclass checks first, then additionally requires a tokenizer factory,
     * without which ParagraphVectors training cannot proceed.
     *
     * <p>NOTE(review): decompiler metadata indicates this override was originally
     * {@code protected}; it is kept {@code public} so existing callers continue to compile.
     *
     * @throws DL4JInvalidConfigException if no {@code TokenizerFactory} is configured (or if the
     *     superclass validation fails)
     */
    @Override
    public void validateConfiguration() {
        super.validateConfiguration();
        if (this.configuration.getTokenizerFactory() == null) {
            throw new DL4JInvalidConfigException("TokenizerFactory is undefined. Can't train ParagraphVectors without it.");
        }
    }

    /**
     * Trains on a pair RDD of string records.
     *
     * <p>Each pair is converted to a sequence by {@code KeySequenceConvertFunction}.
     * NOTE(review): presumably the key identifies/labels the document (e.g. a file path, as
     * produced by {@code wholeTextFiles}) and the value is the document text — confirm against
     * {@code KeySequenceConvertFunction}.
     *
     * @param javaPairRDD the input records; must not be {@code null}
     * @throws NullPointerException if {@code javaPairRDD} is {@code null}
     * @throws DL4JInvalidConfigException if the configuration is invalid
     */
    public void fitMultipleFiles(JavaPairRDD<String, String> javaPairRDD) {
        Objects.requireNonNull(javaPairRDD, "javaPairRDD must not be null");
        validateConfiguration();
        // Wraps the RDD's existing SparkContext; no new Spark context is started.
        // Must run before the map below, which reads this.configurationBroadcast
        // (presumably populated by broadcastEnvironment).
        broadcastEnvironment(JavaSparkContext.fromSparkContext(javaPairRDD.context()));
        super.fitSequences(javaPairRDD.map(new KeySequenceConvertFunction(this.configurationBroadcast)));
    }

    /**
     * Trains on an RDD of {@link LabelledDocument}s.
     *
     * <p>Each document is converted to a sequence by {@code DocumentSequenceConvertFunction}.
     *
     * @param javaRDD the input documents; must not be {@code null}
     * @throws NullPointerException if {@code javaRDD} is {@code null}
     * @throws DL4JInvalidConfigException if the configuration is invalid
     */
    public void fitLabelledDocuments(JavaRDD<LabelledDocument> javaRDD) {
        Objects.requireNonNull(javaRDD, "javaRDD must not be null");
        validateConfiguration();
        // Wraps the RDD's existing SparkContext; no new Spark context is started.
        // Must run before the map below, which reads this.configurationBroadcast
        // (presumably populated by broadcastEnvironment).
        broadcastEnvironment(JavaSparkContext.fromSparkContext(javaRDD.context()));
        super.fitSequences(javaRDD.map(new DocumentSequenceConvertFunction(this.configurationBroadcast)));
    }
}
