package org.nd4j.linalg.dataset.api.preprocessor;

import java.util.Objects;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
import org.nd4j.linalg.exception.ND4JIllegalStateException;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/AbstractDataSetNormalizer.class */
/**
 * Base class for {@link DataNormalization} implementations that collect statistics
 * ({@link NormalizerStats}) from data sets and delegate the actual transform/revert computation
 * to a {@link NormalizerStrategy}.
 *
 * <p>Feature statistics are always collected by {@code fit(...)}. Label statistics are only
 * collected while {@link #fitLabel(boolean)} is {@code true}; while it is {@code false}, label
 * transforms and reverts are no-ops.
 *
 * @param <S> the statistics type that the builder returned by {@link #newBuilder()} is expected to
 *     produce and that {@link #strategy} consumes
 */
public abstract class AbstractDataSetNormalizer<S extends NormalizerStats> extends AbstractNormalizer implements DataNormalization {
    /** Strategy that applies and reverts the normalization using the collected statistics. */
    protected NormalizerStrategy<S> strategy;
    /** Statistics over the features; {@code null} until a {@code fit(...)} call has completed. */
    private S featureStats;
    /** Statistics over the labels; only assigned by a {@code fit(...)} made while fitLabels is true. */
    private S labelStats;
    /** Whether fitting collects label statistics and whether label transforms/reverts apply. */
    private boolean fitLabels = false;

    /**
     * Creates a normalizer backed by the given strategy.
     *
     * @param normalizerStrategy strategy used to transform and revert data
     */
    protected AbstractDataSetNormalizer(NormalizerStrategy<S> normalizerStrategy) {
        this.strategy = normalizerStrategy;
    }

    /**
     * Sets whether subsequent fits should also collect label statistics, and whether label
     * transforms/reverts are applied.
     *
     * @param fitLabel {@code true} to normalize labels as well as features
     */
    @Override
    public void fitLabel(boolean fitLabel) {
        this.fitLabels = fitLabel;
    }

    /** Returns whether labels are fitted and normalized in addition to features. */
    @Override
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    /**
     * Computes feature statistics (and label statistics, if label fitting is enabled) from a
     * single data set, replacing any previously computed statistics of the same kind.
     *
     * @param dataSet the data to compute statistics from
     */
    @Override
    @SuppressWarnings("unchecked")
    public void fit(DataSet dataSet) {
        // newBuilder() returns a raw Builder, so build() is not statically typed as S; the cast is
        // required (as in fit(DataSetIterator)) and relies on the subclass's builder producing S.
        this.featureStats = (S) newBuilder().addFeatures(dataSet).build();
        if (isFitLabel()) {
            this.labelStats = (S) newBuilder().addLabels(dataSet).build();
        }
    }

    /**
     * Returns the collected feature statistics, or {@code null} if not yet fitted.
     * NOTE(review): the decompiler reports this was originally {@code protected}.
     */
    public S getFeatureStats() {
        return this.featureStats;
    }

    /**
     * Returns the collected label statistics, or {@code null} if labels were never fitted.
     * NOTE(review): the decompiler reports this was originally {@code protected}.
     */
    public S getLabelStats() {
        return this.labelStats;
    }

    /** A normalizer counts as fitted once feature statistics are available. */
    @Override
    protected boolean isFit() {
        return this.featureStats != null;
    }

    /**
     * Computes statistics over every data set produced by the iterator. The iterator is reset
     * before iterating and reset again afterwards so callers can reuse it.
     *
     * @param dataSetIterator source of the data to compute statistics from
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void fit(DataSetIterator dataSetIterator) {
        NormalizerStats.Builder featureBuilder = newBuilder();
        NormalizerStats.Builder labelBuilder = newBuilder();
        dataSetIterator.reset();
        while (dataSetIterator.hasNext()) {
            org.nd4j.linalg.dataset.DataSet next = dataSetIterator.next();
            featureBuilder.addFeatures(next);
            if (this.fitLabels) {
                labelBuilder.addLabels(next);
            }
        }
        this.featureStats = (S) featureBuilder.build();
        if (this.fitLabels) {
            this.labelStats = (S) labelBuilder.build();
        }
        dataSetIterator.reset();
    }

    /**
     * Creates a fresh statistics builder. Its {@code build()} result is cast to {@code S} by the
     * {@code fit} methods.
     */
    protected abstract NormalizerStats.Builder newBuilder();

    /**
     * Normalizes the features (and labels, if label fitting is enabled) of a data set, honouring
     * the respective mask arrays. The strategy returns nothing, so it presumably modifies the
     * arrays in place.
     *
     * @param dataSet the data set to normalize
     * @throws NullPointerException if {@code dataSet} is {@code null}
     * @throws ND4JIllegalStateException if the required statistics have not been fitted
     */
    @Override
    public void preProcess(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("toPreProcess is marked non-null but is null");
        }
        transform(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
        transformLabel(dataSet.getLabels(), dataSet.getLabelsMaskArray());
    }

    /** Equivalent to {@link #preProcess(DataSet)}. */
    @Override
    public void transform(DataSet dataSet) {
        preProcess(dataSet);
    }

    /** Normalizes a feature array without a mask. */
    @Override
    public void transform(INDArray features) {
        transform(features, null);
    }

    /**
     * Normalizes a feature array using the fitted feature statistics.
     *
     * @param features the feature array
     * @param featuresMask optional mask, may be {@code null}
     * @throws ND4JIllegalStateException if {@code fit(...)} has not been called
     */
    @Override
    public void transform(INDArray features, INDArray featuresMask) {
        this.strategy.preProcess(features, featuresMask, requireFeatureStats());
    }

    /** Normalizes a label array without a mask (no-op unless label fitting is enabled). */
    @Override
    public void transformLabel(INDArray labels) {
        transformLabel(labels, null);
    }

    /**
     * Normalizes a label array using the fitted label statistics; does nothing unless label
     * fitting is enabled.
     *
     * @param labels the label array
     * @param labelsMask optional mask, may be {@code null}
     * @throws ND4JIllegalStateException if label fitting is enabled but labels were never fitted
     */
    @Override
    public void transformLabel(INDArray labels, INDArray labelsMask) {
        if (isFitLabel()) {
            this.strategy.preProcess(labels, labelsMask, requireLabelStats());
        }
    }

    /** Reverts normalization of a feature array without a mask. */
    @Override
    public void revertFeatures(INDArray features) {
        revertFeatures(features, null);
    }

    /**
     * Reverts normalization of a feature array using the fitted feature statistics.
     *
     * @param features the normalized feature array
     * @param featuresMask optional mask, may be {@code null}
     * @throws ND4JIllegalStateException if {@code fit(...)} has not been called
     */
    @Override
    public void revertFeatures(INDArray features, INDArray featuresMask) {
        this.strategy.revert(features, featuresMask, requireFeatureStats());
    }

    /** Reverts normalization of a label array without a mask (no-op unless labels are fitted). */
    @Override
    public void revertLabels(INDArray labels) {
        revertLabels(labels, null);
    }

    /**
     * Reverts normalization of a label array; does nothing unless label fitting is enabled.
     *
     * @param labels the normalized label array
     * @param labelsMask optional mask, may be {@code null}
     * @throws ND4JIllegalStateException if label fitting is enabled but labels were never fitted
     */
    @Override
    public void revertLabels(INDArray labels, INDArray labelsMask) {
        if (isFitLabel()) {
            this.strategy.revert(labels, labelsMask, requireLabelStats());
        }
    }

    /**
     * Reverts normalization of a data set's features and (if label fitting is enabled) labels,
     * honouring the respective mask arrays.
     */
    @Override
    public void revert(DataSet dataSet) {
        revertFeatures(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
        revertLabels(dataSet.getLabels(), dataSet.getLabelsMaskArray());
    }

    /** Returns the feature statistics, failing fast with a clear message if not yet fitted. */
    private S requireFeatureStats() {
        S stats = getFeatureStats();
        if (stats == null) {
            throw new ND4JIllegalStateException("Features statistics were not yet calculated. Make sure to run fit() first.");
        }
        return stats;
    }

    /** Returns the label statistics, failing fast with a clear message if labels were not fitted. */
    private S requireLabelStats() {
        S stats = getLabelStats();
        if (stats == null) {
            throw new ND4JIllegalStateException("Label statistics were not yet calculated. Make sure to call fitLabel(true) and then fit() first.");
        }
        return stats;
    }

    /**
     * Two normalizers are equal when they mutually accept each other via {@link #canEqual} and
     * share the same strategy, feature/label statistics and label-fitting flag.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AbstractDataSetNormalizer)) {
            return false;
        }
        AbstractDataSetNormalizer<?> other = (AbstractDataSetNormalizer<?>) obj;
        return other.canEqual(this)
                && Objects.equals(this.strategy, other.strategy)
                && Objects.equals(getFeatureStats(), other.getFeatureStats())
                && Objects.equals(getLabelStats(), other.getLabelStats())
                && this.fitLabels == other.fitLabels;
    }

    /** Hook allowing subclasses to restrict equality to their own type (Lombok convention). */
    protected boolean canEqual(Object obj) {
        return obj instanceof AbstractDataSetNormalizer;
    }

    /** Hash code consistent with {@link #equals(Object)}; uses the Lombok-style constants. */
    @Override
    public int hashCode() {
        final int prime = 59;
        S features = getFeatureStats();
        S labels = getLabelStats();
        int result = 1;
        result = result * prime + (this.strategy == null ? 43 : this.strategy.hashCode());
        result = result * prime + (features == null ? 43 : features.hashCode());
        result = result * prime + (labels == null ? 43 : labels.hashCode());
        result = result * prime + (this.fitLabels ? 79 : 97);
        return result;
    }

    /**
     * Replaces the feature statistics (e.g. when restoring a serialized normalizer).
     * NOTE(review): the decompiler reports this was originally {@code protected}.
     */
    public void setFeatureStats(S s) {
        this.featureStats = s;
    }

    /**
     * Replaces the label statistics (e.g. when restoring a serialized normalizer).
     * NOTE(review): the decompiler reports this was originally {@code protected}.
     */
    public void setLabelStats(S s) {
        this.labelStats = s;
    }
}
