package org.deeplearning4j.spark.models.sequencevectors.functions;

import java.util.Iterator;
import lombok.NonNull;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.function.Function;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.spark.models.sequencevectors.primitives.ExtraCounter;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/spark/models/sequencevectors/functions/ExtraCountFunction.class */
/**
 * Spark {@link Function} that tallies element frequencies for one {@link Sequence} at a time.
 *
 * <p>For every sequence, a fresh {@link ExtraCounter} is filled with one count per non-null
 * element and one count per non-null sequence label, keyed by {@code getStorageId()}. The
 * counter's {@code buildNetworkSnapshot()} is then invoked, and the counter is added to the
 * shared {@link Accumulator}. The sequence is returned paired with the number of non-null
 * elements it contains; labels are not included in that length.
 *
 * @param <T> element type of the sequences being counted
 */
public class ExtraCountFunction<T extends SequenceElement> implements Function<Sequence<T>, Pair<Sequence<T>, Long>> {
    /** Receives the per-sequence counter produced by every {@link #call} invocation. */
    protected Accumulator<ExtraCounter<Long>> accumulator;

    /**
     * Flag supplied at construction time.
     *
     * <p>NOTE(review): {@link #call} never reads this flag and always counts sequence labels.
     * Presumably it is intended for subclasses or is vestigial; confirm before relying on it.
     */
    protected boolean fetchLabels;

    /**
     * Creates a counting function bound to the given accumulator.
     *
     * @param accumulator Spark accumulator that collects per-sequence counters; must not be null
     * @param fetchLabels stored in {@link #fetchLabels}; not consulted by {@link #call}
     * @throws NullPointerException if {@code accumulator} is null
     */
    public ExtraCountFunction(@NonNull Accumulator<ExtraCounter<Long>> accumulator, boolean fetchLabels) {
        if (accumulator == null) {
            throw new NullPointerException("accumulator is marked @NonNull but is null");
        }
        this.accumulator = accumulator;
        this.fetchLabels = fetchLabels;
    }

    /**
     * Counts the elements and labels of {@code sequence} and publishes the counts to the
     * accumulator.
     *
     * @param sequence the sequence to count
     * @return {@code sequence} paired with its number of non-null elements
     * @throws Exception as permitted by the Spark {@code Function} contract
     */
    public Pair<Sequence<T>, Long> call(Sequence<T> sequence) throws Exception {
        ExtraCounter<Long> localCounter = new ExtraCounter<>();

        // The returned length is the number of non-null elements, recomputed here while the
        // frequencies are gathered; null slots are neither counted nor measured.
        long seqLen = 0;
        for (T element : sequence.getElements()) {
            if (element != null) {
                localCounter.incrementCount(element.getStorageId(), 1.0d);
                seqLen++;
            }
        }

        // Labels contribute to the frequency counts but not to the returned length. Null labels
        // are skipped, matching the element handling above, instead of raising an NPE.
        if (sequence.getSequenceLabels() != null) {
            for (T label : sequence.getSequenceLabels()) {
                if (label != null) {
                    localCounter.incrementCount(label.getStorageId(), 1.0d);
                }
            }
        }

        localCounter.buildNetworkSnapshot();
        this.accumulator.add(localCounter);
        return Pair.makePair(sequence, Long.valueOf(seqLen));
    }
}
