package org.nd4j.parameterserver.distributed.messages.intercom;

import org.apache.camel.util.URISupport;
import org.apache.commons.compress.compressors.bzip2.BZip2Constants;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.logic.storage.WordVectorStorage;
import org.nd4j.parameterserver.distributed.messages.BaseVoidMessage;
import org.nd4j.parameterserver.distributed.messages.DistributedMessage;
import org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Intercom message that asks a shard to initialize its local word-vector storage: {@code syn0},
 * optionally {@code syn1} (hierarchical softmax) and {@code syn1Neg} (negative sampling), plus a
 * precomputed sigmoid lookup table. Afterwards, it sends an {@link InitializationAggregation} back
 * through the transport.
 *
 * <p>Processing is skipped entirely when {@code syn0} is already present in storage. As a result,
 * delivering this message more than once does not re-initialize a shard.
 *
 * <p>NOTE(review): {@link #processMessage()} mutates {@code columnsPerShard}, which participates in
 * {@link #equals(Object)} / {@link #hashCode()}. Avoid using instances as hash keys across
 * processing.
 *
 * @deprecated marked deprecated in this codebase; no replacement is referenced here — TODO confirm
 *     the intended successor API.
 */
@Deprecated
public class DistributedInitializationMessage extends BaseVoidMessage implements DistributedMessage {
    private static final Logger log = LoggerFactory.getLogger(DistributedInitializationMessage.class);

    /** Number of samples in the sigmoid lookup table built during initialization. */
    private static final int EXP_TABLE_SIZE = 100000;

    /** The sigmoid lookup table samples inputs over {@code [-MAX_EXP, MAX_EXP)}. */
    private static final double MAX_EXP = 6.0;

    protected int vectorLength;
    protected int numWords;
    protected long seed;
    protected boolean useHs;
    protected boolean useNeg;
    protected int columnsPerShard;

    /** Fluent builder for {@link DistributedInitializationMessage}; obtain via {@link #builder()}. */
    public static class DistributedInitializationMessageBuilder {
        private int vectorLength;
        private int numWords;
        private long seed;
        private boolean useHs;
        private boolean useNeg;
        private int columnsPerShard;

        DistributedInitializationMessageBuilder() {
        }

        public DistributedInitializationMessageBuilder vectorLength(int vectorLength) {
            this.vectorLength = vectorLength;
            return this;
        }

        public DistributedInitializationMessageBuilder numWords(int numWords) {
            this.numWords = numWords;
            return this;
        }

        public DistributedInitializationMessageBuilder seed(long seed) {
            this.seed = seed;
            return this;
        }

        public DistributedInitializationMessageBuilder useHs(boolean useHs) {
            this.useHs = useHs;
            return this;
        }

        public DistributedInitializationMessageBuilder useNeg(boolean useNeg) {
            this.useNeg = useNeg;
            return this;
        }

        public DistributedInitializationMessageBuilder columnsPerShard(int columnsPerShard) {
            this.columnsPerShard = columnsPerShard;
            return this;
        }

        /** Creates a message from the values collected so far. */
        public DistributedInitializationMessage build() {
            return new DistributedInitializationMessage(
                    vectorLength, numWords, seed, useHs, useNeg, columnsPerShard);
        }

        @Override
        public String toString() {
            return "DistributedInitializationMessage.DistributedInitializationMessageBuilder(vectorLength="
                    + vectorLength + ", numWords=" + numWords + ", seed=" + seed + ", useHs=" + useHs
                    + ", useNeg=" + useNeg + ", columnsPerShard=" + columnsPerShard + ")";
        }
    }

    /**
     * Creates a fully populated initialization message.
     *
     * @param vectorLength length of each word vector
     * @param numWords number of rows (words) to allocate
     * @param seed base RNG seed; each shard multiplies it by {@code shardIndex + 1}
     * @param useHs whether to allocate {@code syn1}
     * @param useNeg whether to allocate {@code syn1Neg}
     * @param columnsPerShard number of vector columns held by a shard (may be adjusted on processing)
     */
    public DistributedInitializationMessage(int vectorLength, int numWords, long seed, boolean useHs,
            boolean useNeg, int columnsPerShard) {
        // NOTE(review): 4 is the message-type code passed to BaseVoidMessage — confirm its meaning there.
        super(4);
        this.vectorLength = vectorLength;
        this.numWords = numWords;
        this.seed = seed;
        this.useHs = useHs;
        this.useNeg = useNeg;
        this.columnsPerShard = columnsPerShard;
    }

    /**
     * Initializes this shard's word-vector arrays unless {@code syn0} already exists in storage.
     *
     * <p>Steps:
     * <ol>
     *   <li>Seed the global ND4J RNG with {@code seed * (shardIndex + 1)}.</li>
     *   <li>Compute {@code columnsPerShard}. In AVERAGING mode it is the full vector length. In
     *       SHARDED mode, the last shard also takes the {@code vectorLength % numberOfShards}
     *       remainder.</li>
     *   <li>Create {@code syn0} from uniform random values shifted by -0.5 and divided by
     *       {@code vectorLength}. Also create {@code syn1}/{@code syn1Neg} via {@code Nd4j.create}
     *       when enabled, and the sigmoid table.</li>
     *   <li>Store everything and send an {@link InitializationAggregation} carrying this
     *       message's originator id.</li>
     * </ol>
     */
    @Override // org.nd4j.parameterserver.distributed.messages.VoidMessage
    public void processMessage() {
        // syn0 acts as the "already initialized" marker; never initialize twice.
        if (storage.getArray(WordVectorStorage.SYN_0) != null) {
            return;
        }

        log.info("sI_{} is starting initialization...", transport.getShardIndex());

        // Per-shard seed: shards draw different but reproducible random streams.
        Nd4j.getRandom().setSeed(seed * (shardIndex + 1));

        ExecutionMode executionMode = voidConfiguration.getExecutionMode();
        if (executionMode == ExecutionMode.AVERAGING) {
            // Every shard holds a complete copy of each vector.
            columnsPerShard = vectorLength;
        } else if (executionMode == ExecutionMode.SHARDED
                && voidConfiguration.getNumberOfShards() - 1 == shardIndex) {
            // The last shard absorbs the leftover columns when vectorLength doesn't split evenly.
            int remainder = vectorLength % voidConfiguration.getNumberOfShards();
            if (remainder != 0) {
                columnsPerShard += remainder;
                log.info("Got inequal split. using higher number of elements: {}", columnsPerShard);
            }
        }

        int[] shardShape = {numWords, columnsPerShard};

        // Draw from the RNG seeded above. Passing an explicit seed to Nd4j.rand here would
        // re-seed the generator and discard the per-shard seed.
        INDArray syn0 = Nd4j.rand(shardShape).subi(0.5).divi(vectorLength);
        INDArray syn1 = useHs ? Nd4j.create(shardShape, 'c') : null;
        INDArray syn1Neg = useNeg ? Nd4j.create(shardShape, 'c') : null;
        INDArray expTable = initExpTable(EXP_TABLE_SIZE);

        storage.setArray(WordVectorStorage.SYN_0, syn0);
        if (useHs) {
            storage.setArray(WordVectorStorage.SYN_1, syn1);
        }
        if (useNeg) {
            storage.setArray(WordVectorStorage.SYN_1_NEGATIVE, syn1Neg);
        }
        storage.setArray(WordVectorStorage.EXP_TABLE, expTable);

        InitializationAggregation aggregation = new InitializationAggregation(
                (short) voidConfiguration.getNumberOfShards(), transport.getShardIndex());
        aggregation.setOriginatorId(originatorId);
        transport.sendMessage(aggregation);
    }

    /**
     * Builds a sigmoid lookup table.
     *
     * <p>Entry {@code k} holds {@code e / (e + 1)} with {@code e = exp(x)} and
     * {@code x = (k / tableWidth * 2 - 1) * MAX_EXP}. The entries are therefore evenly spaced
     * sigmoid samples over {@code [-MAX_EXP, MAX_EXP)}.
     *
     * @param tableWidth number of samples
     * @return an array created by {@code Nd4j.create(double[])} holding the samples
     */
    protected INDArray initExpTable(int tableWidth) {
        double[] table = new double[tableWidth];
        for (int k = 0; k < table.length; k++) {
            // Floating-point division is required. With int / int the ratio truncates to 0, and
            // every entry collapses to sigmoid(-MAX_EXP).
            double x = ((double) k / table.length * 2.0 - 1.0) * MAX_EXP;
            double e = FastMath.exp(x);
            table[k] = e / (e + 1.0);
        }
        return Nd4j.create(table);
    }

    /** Returns a new builder with all fields at their Java defaults. */
    public static DistributedInitializationMessageBuilder builder() {
        return new DistributedInitializationMessageBuilder();
    }

    /** No-arg constructor; leaves all fields at their Java defaults. */
    public DistributedInitializationMessage() {
    }

    public int getVectorLength() {
        return vectorLength;
    }

    public int getNumWords() {
        return numWords;
    }

    public long getSeed() {
        return seed;
    }

    public boolean isUseHs() {
        return useHs;
    }

    public boolean isUseNeg() {
        return useNeg;
    }

    public int getColumnsPerShard() {
        return columnsPerShard;
    }

    public void setVectorLength(int vectorLength) {
        this.vectorLength = vectorLength;
    }

    public void setNumWords(int numWords) {
        this.numWords = numWords;
    }

    public void setSeed(long seed) {
        this.seed = seed;
    }

    public void setUseHs(boolean useHs) {
        this.useHs = useHs;
    }

    public void setUseNeg(boolean useNeg) {
        this.useNeg = useNeg;
    }

    public void setColumnsPerShard(int columnsPerShard) {
        this.columnsPerShard = columnsPerShard;
    }

    /** Compares the six configuration fields declared by this class; inherited state is ignored. */
    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof DistributedInitializationMessage)) {
            return false;
        }
        DistributedInitializationMessage other = (DistributedInitializationMessage) obj;
        return other.canEqual(this)
                && getVectorLength() == other.getVectorLength()
                && getNumWords() == other.getNumWords()
                && getSeed() == other.getSeed()
                && isUseHs() == other.isUseHs()
                && isUseNeg() == other.isUseNeg()
                && getColumnsPerShard() == other.getColumnsPerShard();
    }

    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage
    protected boolean canEqual(Object obj) {
        return obj instanceof DistributedInitializationMessage;
    }

    /** Hash over the same fields as {@link #equals(Object)}, using a prime-59 accumulator. */
    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + getVectorLength();
        result = result * prime + getNumWords();
        result = result * prime + Long.hashCode(getSeed());
        result = result * prime + (isUseHs() ? 79 : 97);
        result = result * prime + (isUseNeg() ? 79 : 97);
        result = result * prime + getColumnsPerShard();
        return result;
    }

    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage
    public String toString() {
        return "DistributedInitializationMessage(vectorLength=" + getVectorLength() + ", numWords="
                + getNumWords() + ", seed=" + getSeed() + ", useHs=" + isUseHs() + ", useNeg="
                + isUseNeg() + ", columnsPerShard=" + getColumnsPerShard() + ")";
    }
}
