package org.nd4j.parameterserver.distributed;

import java.net.InterfaceAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.enums.NodeRole;
import org.nd4j.parameterserver.distributed.logic.Storage;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.logic.sequence.BasicSequenceProvider;
import org.nd4j.parameterserver.distributed.logic.storage.WordVectorStorage;
import org.nd4j.parameterserver.distributed.messages.Frame;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.messages.VoidMessage;
import org.nd4j.parameterserver.distributed.messages.requests.InitializationRequestMessage;
import org.nd4j.parameterserver.distributed.messages.requests.VectorRequestMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer;
import org.nd4j.parameterserver.distributed.transport.RoutedTransport;
import org.nd4j.parameterserver.distributed.transport.Transport;
import org.nd4j.parameterserver.distributed.util.NetworkOrganizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/parameterserver/distributed/VoidParameterServer.class */
public class VoidParameterServer {
    protected volatile NodeRole nodeRole;
    protected volatile VoidConfiguration voidConfiguration;
    protected AtomicBoolean initLocker;
    protected AtomicBoolean initFinished;
    protected AtomicBoolean shutdownLocker;
    protected AtomicBoolean shutdownFinished;
    protected transient Transport transport;
    protected transient AtomicBoolean manualMode;
    protected transient AtomicBoolean runner;
    protected transient Thread[] processingThreads;
    protected transient Runnable[] processingRunnables;
    protected transient TrainingDriver<? extends TrainingMessage> trainer;
    protected short shardIndex;
    protected Clipboard clipboard;
    protected Storage storage;
    protected Map<String, Frame<TrainingMessage>> frames;
    protected ThreadPoolExecutor executor;
    private static final Logger log = LoggerFactory.getLogger(VoidParameterServer.class);
    private static final VoidParameterServer INSTANCE = new VoidParameterServer();
    protected static final int numThreads = Runtime.getRuntime().availableProcessors() * 2;
    protected static double MAX_EXP = 6.0d;

    protected VoidParameterServer() {
        this.initLocker = new AtomicBoolean(false);
        this.initFinished = new AtomicBoolean(false);
        this.shutdownLocker = new AtomicBoolean(false);
        this.shutdownFinished = new AtomicBoolean(false);
        this.manualMode = new AtomicBoolean(false);
        this.runner = new AtomicBoolean(false);
        this.clipboard = new Clipboard();
        this.storage = new WordVectorStorage();
        this.frames = new ConcurrentHashMap();
        this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
        this.nodeRole = NodeRole.NONE;
    }

    protected VoidParameterServer(@NonNull NodeRole nodeRole) {
        this.initLocker = new AtomicBoolean(false);
        this.initFinished = new AtomicBoolean(false);
        this.shutdownLocker = new AtomicBoolean(false);
        this.shutdownFinished = new AtomicBoolean(false);
        this.manualMode = new AtomicBoolean(false);
        this.runner = new AtomicBoolean(false);
        this.clipboard = new Clipboard();
        this.storage = new WordVectorStorage();
        this.frames = new ConcurrentHashMap();
        this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
        if (nodeRole == null) {
            throw new NullPointerException("nodeRole");
        }
        this.nodeRole = nodeRole;
    }

    protected VoidParameterServer(boolean z) {
        this();
        this.manualMode.set(z);
    }

    public static VoidParameterServer getInstance() {
        return INSTANCE;
    }

    public void setTrainingDriver(@NonNull TrainingDriver<? extends TrainingMessage> trainingDriver) {
        if (trainingDriver == null) {
            throw new NullPointerException("trainer");
        }
        this.trainer = trainingDriver;
    }

    public short getShardIndex() {
        return this.shardIndex;
    }

    protected void setIpPortForShard(String str, int i) {
        this.transport.setIpAndPort(str, i);
    }

    protected void setShardIndex(short s) {
        this.shardIndex = s;
    }

    protected Transport getTransport() {
        return this.transport;
    }

    protected INDArray getSyn0() {
        return this.storage.getArray(WordVectorStorage.SYN_0);
    }

    protected INDArray getSyn1() {
        return this.storage.getArray(WordVectorStorage.SYN_1);
    }

    protected INDArray getSyn1Neg() {
        return this.storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
    }

    protected INDArray getExpTable() {
        return this.storage.getArray(WordVectorStorage.EXP_TABLE);
    }

    protected INDArray getNegTable() {
        return this.storage.getArray(WordVectorStorage.NEGATIVE_TABLE);
    }

    protected void init(@NonNull VoidConfiguration voidConfiguration) {
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration");
        }
        init(voidConfiguration, new RoutedTransport(), new SkipGramTrainer());
    }

    public boolean isInit() {
        return this.initFinished.get();
    }

    public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, TrainingDriver<? extends TrainingMessage> trainingDriver) {
        String str;
        int unicastPort;
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration");
        }
        if (transport == null) {
            throw new NullPointerException("transport");
        }
        if (this.initFinished.get()) {
            return;
        }
        synchronized (this) {
            if (this.initLocker.compareAndSet(false, true)) {
                this.trainer = trainingDriver;
                this.voidConfiguration = voidConfiguration;
                this.transport = transport;
                if (this.nodeRole == NodeRole.NONE && (voidConfiguration.getForcedRole() == null || voidConfiguration.getForcedRole() == NodeRole.NONE)) {
                    Pair<NodeRole, String> create = (voidConfiguration.getShardAddresses().size() == 1 && voidConfiguration.getShardAddresses().get(0).contains("127.0.0.1")) ? Pair.create(NodeRole.SHARD, voidConfiguration.getShardAddresses().get(0)) : getRole(voidConfiguration, getLocalAddresses());
                    this.nodeRole = (NodeRole) create.getFirst();
                    String str2 = (String) create.getSecond();
                    if (str2.contains(":")) {
                        String[] split = str2.split(":");
                        str = split[0];
                        unicastPort = Integer.valueOf(split[1]).intValue();
                    } else {
                        str = str2;
                        unicastPort = voidConfiguration.getUnicastPort();
                    }
                    if (this.nodeRole == NodeRole.SHARD && voidConfiguration.getShardAddresses().size() > 1) {
                        short s = 0;
                        for (String str3 : voidConfiguration.getShardAddresses()) {
                            if ((str3.contains(":") ? str2.split(":")[0] : str3).equals(str)) {
                                this.shardIndex = s;
                            }
                            s = (short) (s + 1);
                        }
                    }
                    this.transport.init(voidConfiguration, this.clipboard, this.nodeRole, str, unicastPort, this.shardIndex);
                } else {
                    if (this.nodeRole == NodeRole.NONE) {
                        this.nodeRole = voidConfiguration.getForcedRole();
                    }
                    this.transport.init(voidConfiguration, this.clipboard, this.nodeRole, voidConfiguration.getExecutionMode() == ExecutionMode.MANAGED ? voidConfiguration.getControllerAddress() : "127.0.0.1", voidConfiguration.getUnicastPort(), this.shardIndex);
                }
                if (!this.manualMode.get()) {
                    this.processingThreads = new Thread[numThreads];
                    this.processingRunnables = new Runnable[numThreads];
                    for (int i = 0; i < numThreads; i++) {
                        this.processingThreads[i] = new Thread(() -> {
                            this.runner.set(true);
                            while (this.runner.get()) {
                                try {
                                    handleMessage(transport.takeMessage());
                                } catch (Exception e) {
                                    throw new RuntimeException(e);
                                } catch (ND4JIllegalStateException e2) {
                                    throw new RuntimeException((Throwable) e2);
                                }
                            }
                        });
                        Nd4j.getAffinityManager().attachThreadToDevice(this.processingThreads[i], Nd4j.getAffinityManager().getDeviceForCurrentThread());
                        this.processingThreads[i].setDaemon(true);
                        this.processingThreads[i].setName("VoidParameterServer messages handling thread");
                        this.processingThreads[i].start();
                    }
                }
                log.info("Launching transport...");
                transport.launch(Transport.ThreadingModel.DEDICATED_THREADS);
                trainingDriver.init(this.voidConfiguration, this.transport, this.storage, this.clipboard);
                this.initFinished.set(true);
            }
        }
    }

    protected VoidParameterServer toggleManualMode(boolean z) {
        this.manualMode.set(z);
        return this;
    }

    protected Pair<NodeRole, String> getRole(@NonNull VoidConfiguration voidConfiguration, @NonNull Collection<String> collection) {
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration");
        }
        if (collection == null) {
            throw new NullPointerException("localIPs");
        }
        NodeRole nodeRole = NodeRole.CLIENT;
        for (String str : voidConfiguration.getShardAddresses()) {
            if (collection.contains(str.replaceAll(":.*", ""))) {
                return Pair.create(NodeRole.SHARD, str);
            }
        }
        if (voidConfiguration.getBackupAddresses() != null) {
            for (String str2 : voidConfiguration.getBackupAddresses()) {
                if (collection.contains(str2.replaceAll(":.*", ""))) {
                    return Pair.create(NodeRole.BACKUP, str2);
                }
            }
        }
        String str3 = System.getenv("SPARK_PUBLIC_DNS");
        if (str3 == null && voidConfiguration.getNetworkMask() != null) {
            str3 = new NetworkOrganizer(voidConfiguration.getNetworkMask()).getMatchingAddress();
        }
        if (str3 == null) {
            str3 = System.getenv("DL4J_VOID_IP");
        }
        log.info("Got [{}] as sparkIp", str3);
        if (str3 == null) {
            throw new ND4JIllegalStateException("Can't get IP address for UDP communcation");
        }
        return Pair.create(nodeRole, str3 + ":" + voidConfiguration.getUnicastPort());
    }

    public void shutdown() {
        if (this.initLocker.get() && this.shutdownLocker.compareAndSet(false, true)) {
            log.info("Shutting down transport...");
            this.transport.shutdown();
            this.executor.shutdown();
        }
    }

    public static Set<String> getLocalAddresses() {
        try {
            ArrayList<NetworkInterface> list = Collections.list(NetworkInterface.getNetworkInterfaces());
            HashSet hashSet = new HashSet();
            for (NetworkInterface networkInterface : list) {
                if (!networkInterface.isLoopback() && networkInterface.isUp()) {
                    Iterator<InterfaceAddress> it = networkInterface.getInterfaceAddresses().iterator();
                    while (it.hasNext()) {
                        String hostAddress = it.next().getAddress().getHostAddress();
                        if (hostAddress != null && !hostAddress.isEmpty() && !hostAddress.contains(":")) {
                            hashSet.add(hostAddress);
                        }
                    }
                }
            }
            return hashSet;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    protected void handleMessage(@NonNull VoidMessage voidMessage) {
        if (voidMessage == null) {
            throw new NullPointerException("message");
        }
        if (voidMessage == null) {
            return;
        }
        if (voidMessage.getTargetId() >= 0 && voidMessage.getTargetId() != this.shardIndex) {
            log.warn("sI_{}: Skipping message: [{}]; TargetIdx: [{}]", new Object[]{Short.valueOf(this.shardIndex), voidMessage.getClass().getSimpleName(), Short.valueOf(voidMessage.getTargetId())});
        } else {
            voidMessage.attachContext(this.voidConfiguration, this.trainer, this.clipboard, this.transport, this.storage, this.nodeRole, this.shardIndex);
            voidMessage.processMessage();
        }
    }

    public void initializeSeqVec(int i, int i2, long j, int i3, boolean z, boolean z2) {
        this.transport.sendMessage(new InitializationRequestMessage(i, i2, j, z, z2, i3));
    }

    public synchronized void execDistributed(@NonNull TrainingMessage trainingMessage) {
        if (trainingMessage == null) {
            throw new NullPointerException("message");
        }
        Frame<TrainingMessage> frame = this.frames.get(trainingMessage.getClass().getSimpleName());
        Frame<TrainingMessage> frame2 = frame;
        if (frame == null) {
            frame2 = new Frame<>(BasicSequenceProvider.getInstance().getNextValue().longValue());
            this.frames.put(trainingMessage.getClass().getSimpleName(), frame2);
        }
        frame2.stackMessage(trainingMessage);
        if (frame2.size() >= 128) {
            this.transport.sendMessage(frame2);
            this.frames.put(trainingMessage.getClass().getSimpleName(), new Frame<>(BasicSequenceProvider.getInstance().getNextValue().longValue()));
        }
    }

    public void execDistributedImmediately(@NonNull TrainingMessage trainingMessage) {
        if (trainingMessage == null) {
            throw new NullPointerException("message");
        }
        this.transport.sendMessageToAllShards(trainingMessage);
    }

    public void execDistributed(@NonNull Frame<? extends TrainingMessage> frame) {
        if (frame == null) {
            throw new NullPointerException("messages");
        }
        this.transport.sendMessage(frame);
    }

    public INDArray getVector(int i) {
        return getVector(WordVectorStorage.SYN_0, i);
    }

    public INDArray getVector(@NonNull Integer num, int i) {
        if (num == null) {
            throw new NullPointerException("key");
        }
        return this.transport.sendMessageAndGetResponse(new VectorRequestMessage(num, i)).getPayload();
    }

    public synchronized void sendMessageToAllShards(@NonNull VoidMessage voidMessage) {
        if (voidMessage == null) {
            throw new NullPointerException("message");
        }
        this.transport.sendMessageToAllShards(voidMessage);
    }

    public void sendMessageToAllClients(@NonNull VoidMessage voidMessage) {
        if (voidMessage == null) {
            throw new NullPointerException("message");
        }
        sendMessageToAllClients(voidMessage, null);
    }

    public synchronized void sendMessageToAllClients(@NonNull VoidMessage voidMessage, Long... lArr) {
        if (voidMessage == null) {
            throw new NullPointerException("message");
        }
        this.transport.sendMessageToAllClients(voidMessage, new Long[0]);
    }

    public NodeRole getNodeRole() {
        return this.nodeRole;
    }
}
