package org.deeplearning4j.iterativereduce.tracker.statetracker.hazelcast;

import com.hazelcast.client.HazelcastClient;
import com.hazelcast.client.config.ClientConfig;
import com.hazelcast.config.Config;
import com.hazelcast.config.JoinConfig;
import com.hazelcast.config.ListConfig;
import com.hazelcast.config.MapConfig;
import com.hazelcast.core.Hazelcast;
import com.hazelcast.core.HazelcastInstance;
import com.hazelcast.core.IAtomicReference;
import com.hazelcast.core.IList;
import com.hazelcast.core.MemberAttributeEvent;
import com.hazelcast.core.MembershipEvent;
import com.hazelcast.core.MembershipListener;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.iterativereduce.actor.core.Job;
import org.deeplearning4j.iterativereduce.actor.util.PortTaken;
import org.deeplearning4j.iterativereduce.tracker.statetracker.StateTracker;
import org.deeplearning4j.scaleout.iterativereduce.multi.UpdateableImpl;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link StateTracker} whose shared state lives in Hazelcast distributed data structures:
 * lists for jobs, workers and topics, a map of heartbeats, and atomic references for the
 * current model, the per-worker job slots and the pretrain/done flags.
 *
 * <p>A tracker constructed with type {@code "master"} starts an embedded Hazelcast member on
 * the requested port and seeds the pretrain/done flags. If that port is already taken, it
 * connects as a client instead (and still seeds the flags). Any other type connects as a
 * client and blocks until the worker list is non-empty.
 *
 * <p>All Hazelcast handles are {@code transient}. They are not restored by Java
 * deserialization.
 */
public class HazelCastStateTracker implements StateTracker<UpdateableImpl> {
    private static final long serialVersionUID = -7374372180080957334L;
    public static final String JOBS = "org.deeplearning4j.jobs";
    public static final String NUM_TIMES_PRETRAIN_RAN = "pretrainran";
    public static final String WORKERS = "org.deeplearning4j.workers";
    public static final String AVAILABLE_WORKERS = "AVAILABLE_WORKERS";
    public static final String NUM_TIMES_RUN_PRETRAIN = "PRETRAIN";
    public static final String TOPICS = "topics";
    public static final String RESULT = "RESULT";
    public static final String DONE = "done";
    public static final String LOCKS = "LOCKS";
    public static final String HEART_BEAT = "heartbeat";
    public static final String IS_PRETRAIN = "ispretrain";
    public static final String RESULT_LOC = "RESULT_LOC";
    public static final int DEFAULT_HAZELCAST_PORT = 2510;
    public static final String CURRENT_JOBS = "JOBS";

    /** Tracker type that starts an embedded member and seeds the shared flags. */
    private static final String MASTER = "master";
    /** Prefix of the per-worker atomic reference that holds that worker's current job. */
    private static final String JOB_PREFIX = "job-";

    private static final Logger log = LoggerFactory.getLogger(HazelCastStateTracker.class);

    // Cluster-wide handles; populated in the constructor from the Hazelcast instance `h`.
    private volatile transient IAtomicReference<UpdateableImpl> master;
    private volatile transient IList<Job> jobs;
    private volatile transient IAtomicReference<Integer> numTimesPretrain;
    private volatile transient IAtomicReference<Integer> numTimesPretrainRan;
    private volatile transient IAtomicReference<Boolean> done;
    private volatile transient IList<String> workers;
    private volatile transient IList<String> topics;
    // Transient like the other proxy handles (was previously the only non-transient one).
    private volatile transient IAtomicReference<Boolean> isPretrain;
    // Hazelcast map proxy; transient for the same reason as the fields above.
    private transient Map<String, Long> heartbeat;

    private transient Config config;
    private transient HazelcastInstance h;
    private String type;
    // -1 unless this tracker started an embedded member.
    private int hazelCastPort;

    /**
     * Creates a master tracker on {@link #DEFAULT_HAZELCAST_PORT}.
     *
     * @throws Exception if Hazelcast cannot be started or connected
     */
    public HazelCastStateTracker() throws Exception {
        this(DEFAULT_HAZELCAST_PORT);
    }

    /**
     * Creates a master tracker on the given port.
     *
     * <p>If the port is taken, it connects as a client to the address {@code "master"}.
     *
     * @param stateTrackerPort port for the embedded Hazelcast member
     * @throws Exception if Hazelcast cannot be started or connected
     */
    public HazelCastStateTracker(int stateTrackerPort) throws Exception {
        this("master", MASTER, stateTrackerPort);
    }

    /**
     * Creates a tracker with the given role.
     *
     * @param connectionString address of an existing Hazelcast member; used only when this
     *     tracker connects as a client
     * @param type {@code "master"} to start an embedded member (unless the port is taken);
     *     anything else connects as a client and waits for the worker list to be populated
     * @param stateTrackerPort port for the embedded member; only used for the master
     * @throws Exception if Hazelcast cannot be started or connected, or if interrupted while
     *     waiting for the worker list
     */
    public HazelCastStateTracker(String connectionString, String type, int stateTrackerPort)
            throws Exception {
        this.hazelCastPort = -1;
        final boolean isMaster = type.equals(MASTER);
        if (!isMaster || PortTaken.portTaken(stateTrackerPort)) {
            log.info("Connecting to hazelcast on {}", connectionString);
            ClientConfig clientConfig = new ClientConfig();
            clientConfig.getNetworkConfig().addAddress(new String[] {connectionString});
            this.h = HazelcastClient.newHazelcastClient(clientConfig);
        } else {
            this.hazelCastPort = stateTrackerPort;
            this.config = hazelcast();
            this.h = Hazelcast.newHazelcastInstance(this.config);
            this.h.getCluster().addMembershipListener(new MembershipListener() {
                @Override
                public void memberAdded(MembershipEvent membershipEvent) {
                    log.info("Member added {}", membershipEvent);
                }

                @Override
                public void memberRemoved(MembershipEvent membershipEvent) {
                    log.info("Member removed {}", membershipEvent);
                }

                @Override
                public void memberAttributeChanged(MemberAttributeEvent memberAttributeEvent) {
                    log.info("Member changed {}", memberAttributeEvent);
                }
            });
        }
        this.type = type;
        this.jobs = h.getList(JOBS);
        this.workers = h.getList(WORKERS);
        if (!isMaster) {
            // Non-master trackers block until the worker list is populated by another member.
            while (workers.isEmpty()) {
                log.warn("Waiting for data sync...");
                Thread.sleep(1000L);
            }
            log.info("Workers is {}", workers.size());
        }
        this.topics = h.getList(TOPICS);
        this.heartbeat = h.getMap(HEART_BEAT);
        this.master = h.getAtomicReference(RESULT);
        this.isPretrain = h.getAtomicReference(IS_PRETRAIN);
        this.numTimesPretrain = h.getAtomicReference(NUM_TIMES_RUN_PRETRAIN);
        this.numTimesPretrainRan = h.getAtomicReference(NUM_TIMES_PRETRAIN_RAN);
        this.done = h.getAtomicReference(DONE);
        if (isMaster) {
            // Seed the shared flags. This also runs when a master fell back to client mode.
            numTimesPretrainRan.set(0);
            numTimesPretrain.set(1);
            isPretrain.set(true);
            done.set(false);
        }
    }

    /**
     * Builds the embedded-member configuration.
     *
     * <p>The configuration uses the fixed port {@link #hazelCastPort} with no auto-increment,
     * multicast join, AWS join disabled, and the Hazelcast shutdown hook disabled.
     */
    private Config hazelcast() {
        Config conf = new Config();
        conf.getNetworkConfig().setPort(hazelCastPort);
        conf.getNetworkConfig().setPortAutoIncrement(false);
        conf.setProperty("hazelcast.initial.min.cluster.size", "1");
        conf.setProperty("hazelcast.shutdownhook.enabled", "false");
        JoinConfig join = conf.getNetworkConfig().getJoin();
        join.getMulticastConfig().setEnabled(true);
        join.getAwsConfig().setEnabled(false);
        conf.addListConfig(namedListConfig(JOBS));
        conf.addListConfig(namedListConfig(TOPICS));
        // NOTE(review): the worker list actually used is WORKERS, not AVAILABLE_WORKERS.
        // These entries only set a name, so presumably they match defaults; confirm if tuned.
        conf.addListConfig(namedListConfig(AVAILABLE_WORKERS));
        MapConfig heartbeatConfig = new MapConfig();
        heartbeatConfig.setName(HEART_BEAT);
        conf.addMapConfig(heartbeatConfig);
        return conf;
    }

    /** Returns a list config carrying only the given name. */
    private static ListConfig namedListConfig(String name) {
        ListConfig listConfig = new ListConfig();
        listConfig.setName(name);
        return listConfig;
    }

    /** Returns the cluster-wide slot holding the current job of {@code workerId}. */
    private IAtomicReference<Job> jobReference(String workerId) {
        return h.getAtomicReference(JOB_PREFIX + workerId);
    }

    /**
     * Claims the job slot of {@code job}'s worker and records the job in the job list.
     *
     * <p>The slot is claimed with a single compare-and-set against {@code null}. As a result,
     * two callers racing for the same worker cannot both succeed. The previous get-then-set
     * sequence allowed that.
     *
     * @return {@code true} if the job was stored, {@code false} if the worker already has one
     */
    @Override
    public boolean addJobToCurrent(Job job) throws Exception {
        IAtomicReference<Job> slot = jobReference(job.getWorkerId());
        if (!slot.compareAndSet(null, job)) {
            log.info("Currently locked unable to add job for current worker");
            return false;
        }
        jobs.add(job);
        return true;
    }

    /** Returns a local snapshot copy of the cluster-wide job list. */
    @Override
    public List<Job> currentJobs() throws Exception {
        return new ArrayList<Job>(jobs);
    }

    /** Destroys the worker's job slot and removes {@code job} from the job list. */
    @Override
    public void clearJob(Job job) throws Exception {
        jobReference(job.getWorkerId()).destroy();
        jobs.remove(job);
    }

    /** Shuts down the underlying Hazelcast instance or client, if one was created. */
    @Override
    public void shutdown() {
        if (h != null) {
            h.shutdown();
        }
    }

    /** Appends {@code topic} to the cluster-wide topic list. */
    @Override
    public void addTopic(String topic) throws Exception {
        topics.add(topic);
    }

    /** Returns the live cluster-wide topic list (not a copy). */
    @Override
    public List<String> topics() throws Exception {
        return topics;
    }

    /**
     * Returns a clone of the current model.
     *
     * @throws NullPointerException if no model has been set yet
     */
    @Override
    public UpdateableImpl getCurrent() throws Exception {
        return master.get().clone();
    }

    /** Publishes {@code updateableImpl} as the current model. */
    @Override
    public void setCurrent(UpdateableImpl updateableImpl) throws Exception {
        master.set(updateableImpl);
    }

    /** Clears {@code job}; any failure is rethrown unchecked. */
    @Override
    public void jobDone(Job job) {
        try {
            clearJob(job);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Checks whether pretraining is still in progress.
     *
     * @return {@code true} while the pretrain flag is set and fewer pretrain runs have
     *     completed than were requested
     */
    @Override
    public boolean isPretrain() {
        return isPretrain.get() && numTimesPretrainRan.get() < runPreTrainIterations();
    }

    /** Clears the pretrain flag. */
    @Override
    public void moveToFinetune() {
        isPretrain.set(false);
    }

    /** Returns the current job of {@code workerId}, or {@code null} if its slot is empty. */
    @Override
    public Job jobFor(String workerId) {
        return jobReference(workerId).get();
    }

    /** Adds {@code workerId} to the worker list if it is not already present. */
    @Override
    public void availableForWork(String workerId) {
        if (!workers.contains(workerId)) {
            workers.add(workerId);
        }
    }

    /** Returns the worker ids of all jobs currently in the job list. */
    @Override
    public List<String> jobIds() {
        List<String> ids = new ArrayList<String>();
        for (Job job : jobs) {
            ids.add(job.getWorkerId());
        }
        return ids;
    }

    /**
     * Records a heartbeat for {@code workerId} and adds it to the worker list if it is absent.
     *
     * <p>The heartbeat is refreshed even when the worker is already known.
     */
    @Override
    public void addWorker(String workerId) {
        log.info("Adding worker {}", workerId);
        heartbeat.put(workerId, Long.valueOf(System.currentTimeMillis()));
        if (workers.contains(workerId)) {
            return;
        }
        workers.add(workerId);
        log.info("Number of workers is now {}", workers.size());
    }

    /** Removes {@code workerId} from the worker list. */
    @Override
    public void removeWorker(String workerId) {
        workers.remove(workerId);
    }

    /** Returns the live cluster-wide worker list (not a copy). */
    @Override
    public List<String> workers() {
        return workers;
    }

    /** Returns the current size of the worker list. */
    @Override
    public int numWorkers() {
        return workers.size();
    }

    /** Returns the underlying Hazelcast instance or client. */
    public synchronized HazelcastInstance getH() {
        return h;
    }

    /**
     * Replaces the stored Hazelcast instance.
     *
     * <p>Handles already obtained in the constructor are not refreshed.
     */
    public synchronized void setH(HazelcastInstance hazelcastInstance) {
        this.h = hazelcastInstance;
    }

    /** Returns the live heartbeat map: worker id to last heartbeat time in epoch millis. */
    @Override
    public Map<String, Long> getHeartBeats() {
        return heartbeat;
    }

    /** Sets the number of pretrain iterations to run. */
    @Override
    public void runPreTrainIterations(int iterations) {
        numTimesPretrain.set(Integer.valueOf(iterations));
    }

    /** Returns the number of pretrain iterations to run. */
    @Override
    public int runPreTrainIterations() {
        return numTimesPretrain.get().intValue();
    }

    /** Returns how many pretrain runs have completed. */
    @Override
    public int numTimesPreTrainRun() {
        return numTimesPretrainRan.get().intValue();
    }

    /**
     * Atomically increments the cluster-wide pretrain-run counter.
     *
     * <p>A compare-and-set retry loop ensures that concurrent increments from different
     * members are not lost. A plain read followed by {@code set} could lose them.
     *
     * @throws NullPointerException if the counter was never initialised (the master seeds it)
     */
    @Override
    public void incrementNumTimesPreTrainRan() {
        for (;;) {
            Integer current = numTimesPretrainRan.get();
            Integer next = Integer.valueOf(current.intValue() + 1);
            if (numTimesPretrainRan.compareAndSet(current, next)) {
                return;
            }
        }
    }

    /**
     * Returns the shared done flag.
     *
     * <p>Returns {@code true} if reading the flag fails (e.g. Hazelcast already shut down).
     */
    @Override
    public boolean isDone() {
        try {
            return done.get().booleanValue();
        } catch (Exception e) {
            log.warn("Hazelcast already shutdown...returning true on isDone()");
            return true;
        }
    }

    /**
     * Sets the shared done flag.
     *
     * <p>This is best effort: failures (e.g. Hazelcast already shut down) are only logged.
     */
    @Override
    public void finish() {
        try {
            done.set(true);
        } catch (Exception e) {
            log.warn("Hazelcast already shutdown...done() being set is pointless");
        }
    }
}
