package org.apache.mahout.classifier.sgd;

import com.google.common.collect.Lists;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutionException;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.ep.EvolutionaryProcess;
import org.apache.mahout.ep.Mapping;
import org.apache.mahout.ep.Payload;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.stats.OnlineAuc;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.7.jar:org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.class */
/**
 * Online logistic regression that tunes its own hyper-parameters with an evolutionary search.
 *
 * <p>A population of {@link CrossFoldLearner}s, each wrapped in a {@link Wrapper}, is managed by an
 * {@link EvolutionaryProcess}. Evolved parameter 0 drives the learner's {@code lambda} and parameter 1
 * its learning rate (see {@link Wrapper#update(double[])}).
 *
 * <p>Training examples are buffered. Once the buffer holds more than {@code bufferSize} examples,
 * every population member is trained on the batch in parallel and scored: by AUC for two-category
 * problems, by log-likelihood otherwise, and {@code NaN} if the model is not yet valid. When the
 * number of records seen exceeds {@code cutoff}, the population is mutated. If
 * {@code freezeSurvivors} is set, the first {@link #SURVIVORS} members then have their
 * learning-rate parameter and step sizes reduced. The next cutoff is scheduled by
 * {@link #nextStep(int)}.
 */
public class AdaptiveLogisticRegression implements OnlineLearner, Writable {
    public static final int DEFAULT_THREAD_COUNT = 20;
    public static final int DEFAULT_POOL_SIZE = 20;

    /** Argument to {@code mutatePopulation}; also the number of leading population members frozen. */
    private static final int SURVIVORS = 2;

    /** Number of examples passed to {@code train} so far. */
    private int record;
    /** Once {@code record} exceeds this value, the population is mutated after the next evaluation. */
    private int cutoff = 1000;
    private int minInterval = 1000;
    private int maxInterval = 1000;
    private int currentStep = 1000;
    /** The buffered examples are evaluated once the buffer grows beyond this size. */
    private int bufferSize = 1000;
    private List<TrainingExample> buffer = Lists.newArrayList();
    private EvolutionaryProcess<Wrapper, CrossFoldLearner> ep;
    /** State returned by the most recent successful evaluation; null until the first one. */
    private State<Wrapper, CrossFoldLearner> best;
    private int threadCount = DEFAULT_THREAD_COUNT;
    private int poolSize = DEFAULT_POOL_SIZE;
    /** Seed state handed to {@link EvolutionaryProcess} whenever the optimizer is (re)built. */
    private State<Wrapper, CrossFoldLearner> seed;
    private int numFeatures;
    private boolean freezeSurvivors = true;

    /**
     * A single buffered training example: tracking key, optional group key, target category and
     * feature vector.
     */
    public static class TrainingExample implements Writable {
        private long key;
        private String groupKey;
        private int actual;
        private Vector instance;

        /** No-arg constructor used when deserializing via {@link #readFields(DataInput)}. */
        private TrainingExample() {
        }

        public TrainingExample(long key, String groupKey, int actual, Vector instance) {
            this.key = key;
            this.groupKey = groupKey;
            this.actual = actual;
            this.instance = instance;
        }

        public long getKey() {
            return key;
        }

        public int getActual() {
            return actual;
        }

        public Vector getInstance() {
            return instance;
        }

        public String getGroupKey() {
            return groupKey;
        }

        @Override
        public void write(DataOutput out) throws IOException {
            out.writeLong(key);
            // Presence flag so that a null group key round-trips.
            boolean hasGroupKey = groupKey != null;
            out.writeBoolean(hasGroupKey);
            if (hasGroupKey) {
                out.writeUTF(groupKey);
            }
            out.writeInt(actual);
            VectorWritable.writeVector(out, instance, true);
        }

        @Override
        public void readFields(DataInput in) throws IOException {
            key = in.readLong();
            // Reset to null when absent so a reused instance does not keep a stale group key.
            groupKey = in.readBoolean() ? in.readUTF() : null;
            actual = in.readInt();
            instance = VectorWritable.readVector(in);
        }
    }

    /**
     * Evolutionary payload wrapping one {@link CrossFoldLearner}. Parameter 0 becomes the learner's
     * {@code lambda}, and parameter 1 becomes its learning rate.
     */
    public static class Wrapper implements Payload<CrossFoldLearner> {
        private CrossFoldLearner wrapped;

        public Wrapper() {
        }

        public Wrapper(int numCategories, int numFeatures, PriorFunction prior) {
            // The leading 5 is passed through unchanged; presumably the cross-validation fold count.
            wrapped = new CrossFoldLearner(5, numCategories, numFeatures, prior);
        }

        /** Returns a new wrapper around a copy of the wrapped learner. */
        @Override
        public Wrapper copy() {
            Wrapper copy = new Wrapper();
            copy.wrapped = wrapped.copy();
            return copy;
        }

        /** Applies evolved parameters: {@code params[0]} → lambda, {@code params[1]} → learning rate. */
        @Override
        public void update(double[] params) {
            wrapped.lambda(params[0]);
            wrapped.learningRate(params[1]);
            wrapped.stepOffset(1);
            wrapped.alpha(1.0);
            wrapped.decayExponent(0.0);
        }

        /**
         * Damps a surviving state.
         *
         * <p>Its learning-rate parameter is lowered by 10 (in the un-mapped parameter space), and its
         * omni and per-parameter step sizes are divided by 20.
         */
        public void freeze(State<Wrapper, CrossFoldLearner> state) {
            double[] params = state.getParams();
            params[1] -= 10;
            state.setOmni(state.getOmni() / 20);
            double[] step = state.getStep();
            for (int i = 0; i < step.length; i++) {
                step[i] /= 20;
            }
        }

        /** Installs the parameter mappings: lambda into [1e-8, 0.1], learning rate into [1e-8, 1]. */
        public void setMappings(State<Wrapper, CrossFoldLearner> state) {
            state.setMap(0, Mapping.logLimit(1.0e-8, 0.1));
            state.setMap(1, Mapping.logLimit(1.0e-8, 1.0));
        }

        public void train(TrainingExample example) {
            wrapped.train(example.getKey(), example.getGroupKey(), example.getActual(), example.getInstance());
        }

        public CrossFoldLearner getLearner() {
            return wrapped;
        }

        @Override
        public String toString() {
            return String.format(Locale.ENGLISH, "auc=%.2f", wrapped.auc());
        }

        public void setAucEvaluator(OnlineAuc auc) {
            wrapped.setAucEvaluator(auc);
        }

        public void write(DataOutput out) throws IOException {
            wrapped.write(out);
        }

        public void readFields(DataInput in) throws IOException {
            wrapped = new CrossFoldLearner();
            wrapped.readFields(in);
        }
    }

    /**
     * Creates an instance with no seed learner and no optimizer.
     *
     * <p>Presumably intended only for deserialization via {@link #readFields(DataInput)}: methods that
     * touch the seed or optimizer fail until those are set.
     */
    public AdaptiveLogisticRegression() {
    }

    public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
        this(numCategories, numFeatures, prior, DEFAULT_THREAD_COUNT, DEFAULT_POOL_SIZE);
    }

    /**
     * Builds the seed learner and an optimizer with the given thread count and population size.
     */
    public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior,
                                      int threadCount, int poolSize) {
        this.numFeatures = numFeatures;
        this.threadCount = threadCount;
        this.poolSize = poolSize;
        seed = new State<>(new double[2], 10.0);
        Wrapper wrapper = new Wrapper(numCategories, numFeatures, prior);
        seed.setPayload(wrapper);
        wrapper.setMappings(seed);
        // Set again once the mappings are installed, as the original does; kept because
        // State.setPayload's exact semantics are not visible from this class.
        seed.setPayload(wrapper);
        setPoolSize(this.poolSize);
    }

    /** Buffers an example, using the running record count as its tracking key. */
    @Override
    public void train(int actual, Vector instance) {
        train(record, null, actual, instance);
    }

    @Override
    public void train(long trackingKey, int actual, Vector instance) {
        train(trackingKey, null, actual, instance);
    }

    /** Buffers an example and evaluates the population once the buffer exceeds {@code bufferSize}. */
    @Override
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        record++;
        buffer.add(new TrainingExample(trackingKey, groupKey, actual, instance));
        if (buffer.size() > bufferSize) {
            trainWithBufferedExamples();
        }
    }

    /**
     * Trains every population member on the buffered examples, records the resulting state in
     * {@code best}, clears the buffer, and mutates the population once {@code record} exceeds
     * {@code cutoff}.
     *
     * @throws IllegalStateException wrapping the cause if a training task fails
     */
    private void trainWithBufferedExamples() {
        try {
            best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {
                @Override
                public double apply(Payload<CrossFoldLearner> payload, double[] params) {
                    Wrapper wrapper = (Wrapper) payload;
                    for (TrainingExample example : buffer) {
                        wrapper.train(example);
                    }
                    CrossFoldLearner learner = wrapper.getLearner();
                    if (!learner.validModel()) {
                        return Double.NaN;
                    }
                    return learner.numCategories() == 2 ? learner.auc() : learner.logLikelihood();
                }
            });
        } catch (InterruptedException e) {
            // Keep the previous best and carry on as before, but restore the interrupt status
            // so that callers can still observe the interruption.
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
        buffer.clear();
        if (record > cutoff) {
            cutoff = nextStep(record);
            ep.mutatePopulation(SURVIVORS);
            if (freezeSurvivors) {
                for (State<Wrapper, CrossFoldLearner> state : ep.getPopulation().subList(0, SURVIVORS)) {
                    state.getPayload().freeze(state);
                }
            }
        }
    }

    /**
     * Computes the record count at which the next population mutation should happen.
     *
     * <p>The step comes from {@link #stepSize(int, double)}, clamped to
     * [{@code minInterval}, {@code maxInterval}]. The result is never earlier than
     * {@code cutoff + currentStep}. {@code currentStep} is updated only when the rounded-up
     * candidate is used.
     */
    public int nextStep(int recordNumber) {
        int step = Math.min(Math.max(stepSize(recordNumber, 2.6), minInterval), maxInterval);
        int candidate = step * (recordNumber / step + 1);
        int earliest = cutoff + currentStep;
        if (candidate < earliest) {
            return earliest;
        }
        currentStep = step;
        return candidate;
    }

    /**
     * Returns a step from the 1-2-5 sequence (1, 2, 5, 10, 20, 50, ...) that grows with
     * {@code recordNumber}.
     *
     * <p>The position in the sequence is {@code floor(multiplier * log10(recordNumber))}. For example,
     * {@code stepSize(1000, 2.6) == 200}. {@code recordNumber} is expected to be positive: 0 produces
     * a negative index and an {@link ArrayIndexOutOfBoundsException}.
     */
    public static int stepSize(int recordNumber, double multiplier) {
        int[] bumps = {1, 2, 5};
        double position = Math.floor(multiplier * Math.log10(recordNumber));
        int bump = bumps[(int) position % bumps.length];
        int scale = (int) Math.pow(10.0, Math.floor(position / bumps.length));
        return bump * scale;
    }

    /**
     * Trains on any remaining buffered examples, then closes every learner.
     *
     * <p>The optimizer is always closed, even if the final pass fails or is interrupted.
     *
     * @throws IllegalStateException wrapping the cause if a task fails
     */
    @Override
    public void close() {
        trainWithBufferedExamples();
        try {
            ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {
                @Override
                public double apply(Payload<CrossFoldLearner> payload, double[] params) {
                    CrossFoldLearner learner = ((Wrapper) payload).getLearner();
                    learner.close();
                    return learner.logLikelihood();
                }
            });
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            ep.close();
        }
    }

    public void setInterval(int interval) {
        setInterval(interval, interval);
    }

    /**
     * Sets the minimum and maximum mutation intervals, each clamped to at least 200, and reschedules
     * the next cutoff from the clamped minimum.
     */
    public void setInterval(int min, int max) {
        minInterval = Math.max(200, min);
        maxInterval = Math.max(200, max);
        // Derive everything from the clamped minimum. The raw argument may be below 200, zero
        // (which would divide by zero) or negative (which would make the buffer size negative).
        cutoff = minInterval * (record / minInterval + 1);
        currentStep = minInterval;
        bufferSize = Math.min(minInterval, bufferSize);
    }

    /** Sets the population size and rebuilds the optimizer from the seed. */
    public void setPoolSize(int poolSize) {
        this.poolSize = poolSize;
        setupOptimizer(poolSize);
    }

    /** Sets the worker thread count and rebuilds the optimizer from the seed. */
    public void setThreadCount(int threadCount) {
        this.threadCount = threadCount;
        setupOptimizer(poolSize);
    }

    /** Installs an AUC evaluator on the seed learner and rebuilds the optimizer from the seed. */
    public void setAucEvaluator(OnlineAuc auc) {
        seed.getPayload().setAucEvaluator(auc);
        setupOptimizer(poolSize);
    }

    private void setupOptimizer(int poolSize) {
        ep = new EvolutionaryProcess<>(threadCount, poolSize, seed);
    }

    public int numFeatures() {
        return numFeatures;
    }

    /** Returns the AUC of the current best learner, or NaN before the first evaluation. */
    public double auc() {
        if (best == null) {
            return Double.NaN;
        }
        return best.getPayload().getLearner().auc();
    }

    public State<Wrapper, CrossFoldLearner> getBest() {
        return best;
    }

    public void setBest(State<Wrapper, CrossFoldLearner> best) {
        this.best = best;
    }

    public int getRecord() {
        return record;
    }

    public void setRecord(int record) {
        this.record = record;
    }

    public int getMinInterval() {
        return minInterval;
    }

    public int getMaxInterval() {
        return maxInterval;
    }

    public int getNumCategories() {
        return seed.getPayload().getLearner().numCategories();
    }

    public PriorFunction getPrior() {
        return seed.getPayload().getLearner().getPrior();
    }

    public void setBuffer(List<TrainingExample> buffer) {
        this.buffer = buffer;
    }

    public List<TrainingExample> getBuffer() {
        return buffer;
    }

    public EvolutionaryProcess<Wrapper, CrossFoldLearner> getEp() {
        return ep;
    }

    public void setEp(EvolutionaryProcess<Wrapper, CrossFoldLearner> ep) {
        this.ep = ep;
    }

    public State<Wrapper, CrossFoldLearner> getSeed() {
        return seed;
    }

    public void setSeed(State<Wrapper, CrossFoldLearner> seed) {
        this.seed = seed;
    }

    public int getNumFeatures() {
        return numFeatures;
    }

    /** Sets the seed learner's averaging window and rebuilds the optimizer from the seed. */
    public void setAveragingWindow(int windowSize) {
        seed.getPayload().getLearner().setWindowSize(windowSize);
        setupOptimizer(poolSize);
    }

    public void setFreezeSurvivors(boolean freezeSurvivors) {
        this.freezeSurvivors = freezeSurvivors;
    }

    /**
     * Serializes the full training state.
     *
     * <p>NOTE: {@code best} is null until the first buffered batch has been evaluated, and is written
     * unconditionally, so calling this earlier fails with a NullPointerException. The wire format is
     * kept unchanged for compatibility with {@link #readFields(DataInput)}.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(record);
        out.writeInt(cutoff);
        out.writeInt(minInterval);
        out.writeInt(maxInterval);
        out.writeInt(currentStep);
        out.writeInt(bufferSize);
        out.writeInt(buffer.size());
        for (TrainingExample example : buffer) {
            example.write(out);
        }
        ep.write(out);
        best.write(out);
        out.writeInt(threadCount);
        out.writeInt(poolSize);
        seed.write(out);
        out.writeInt(numFeatures);
        out.writeBoolean(freezeSurvivors);
    }

    /** Restores the state written by {@link #write(DataOutput)}, in the same field order. */
    @Override
    public void readFields(DataInput in) throws IOException {
        record = in.readInt();
        cutoff = in.readInt();
        minInterval = in.readInt();
        maxInterval = in.readInt();
        currentStep = in.readInt();
        bufferSize = in.readInt();
        int bufferedCount = in.readInt();
        buffer = Lists.newArrayList();
        for (int i = 0; i < bufferedCount; i++) {
            TrainingExample example = new TrainingExample();
            example.readFields(in);
            buffer.add(example);
        }
        ep = new EvolutionaryProcess<>();
        ep.readFields(in);
        best = new State<>();
        best.readFields(in);
        threadCount = in.readInt();
        poolSize = in.readInt();
        seed = new State<>();
        seed.readFields(in);
        numFeatures = in.readInt();
        freezeSurvivors = in.readBoolean();
    }
}
