package org.nd4j.autodiff.samediff.internal.memory;

import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.DependencyList;
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/samediff/internal/memory/CloseValidationMemoryMgr.class */
/**
 * A validating {@link SessionMemMgr} wrapper.
 *
 * <p>All allocations are delegated to an underlying memory manager. This class records, by array
 * identity, whether each allocated array has been released. It detects:
 * <ul>
 *   <li>releases of arrays that were not allocated through this manager,</li>
 *   <li>double releases of the same array, and</li>
 *   <li>(via {@link #assertAllReleasedExcept(Collection)}) arrays that should have been released
 *       but were not.</li>
 * </ul>
 *
 * <p>Note: the bookkeeping map is a plain {@link IdentityHashMap} with no synchronization.
 */
public class CloseValidationMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {
    private static final Logger log = LoggerFactory.getLogger(CloseValidationMemoryMgr.class);

    private final SameDiff sd;
    private final SessionMemMgr underlying;
    /** Keyed by array identity; value is {@code true} once {@link #release(INDArray)} was called for it. */
    private final Map<INDArray, Boolean> released = new IdentityHashMap<>();

    /**
     * @param sameDiff      graph whose per-thread inference session is consulted for diagnostics
     * @param sessionMemMgr memory manager that performs the actual allocations
     */
    public CloseValidationMemoryMgr(SameDiff sameDiff, SessionMemMgr sessionMemMgr) {
        this.sd = sameDiff;
        this.underlying = sessionMemMgr;
    }

    /** Allocates via the underlying manager and records the result as "not yet released". */
    @Override
    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
        INDArray arr = underlying.allocate(detached, dataType, shape);
        released.put(arr, false);
        return arr;
    }

    /** Allocates via the underlying manager and records the result as "not yet released". */
    @Override
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        INDArray arr = underlying.allocate(detached, descriptor);
        released.put(arr, false);
        return arr;
    }

    /**
     * Marks {@code arr} as released. The array is not actually freed here, and the underlying
     * manager is not called.
     *
     * @throws IllegalStateException if {@code arr} was not allocated by this manager, or was already released
     */
    @Override
    public void release(INDArray arr) {
        Preconditions.checkState(released.containsKey(arr), "Attempting to release an array that was not allocated by this memory manager: id=%s", arr.getId());
        if (released.get(arr)) {
            // Double release: log the array's dependency state to help locate the culprit before failing below
            logDoubleReleaseDiagnostics(arr);
        }
        Preconditions.checkState(!released.get(arr), "Attempting to release an array that was already deallocated by an earlier release call to this memory manager: id=%s", arr.getId());
        log.trace("Released array: id = {}", arr.getId());
        released.put(arr, true);
    }

    /** Closes the underlying memory manager. */
    @Override
    public void close() {
        underlying.close();
    }

    /**
     * Asserts that every array allocated through this manager has been released, except those in
     * {@code except}.
     *
     * <p>Each array in {@code except} must either have been allocated here and not released, or be
     * the array of a VARIABLE, CONSTANT or PLACEHOLDER variable of the graph.
     *
     * @param except arrays that are expected to still be live (e.g. requested outputs)
     * @throws NullPointerException  if {@code except} is null
     * @throws IllegalStateException if an excepted array was released or is unknown, or if any other
     *                               allocated array was not released
     */
    public void assertAllReleasedExcept(@NonNull Collection<INDArray> except) {
        if (except == null) {
            throw new NullPointerException("except is marked @NonNull but is null");
        }

        // Validate the excepted arrays; the const/placeholder/variable set is only built if needed
        Set<INDArray> constPhVar = null;
        for (INDArray arr : except) {
            if (!released.containsKey(arr)) {
                if (constPhVar == null) {
                    constPhVar = identitySetAllConstPhVar();
                }
                if (!constPhVar.contains(arr)) {
                    throw new IllegalStateException("Array " + arr.getId() + " was not originally allocated by the memory manager");
                }
            } else if (released.get(arr)) {
                throw new IllegalStateException("Specified output array (id=" + arr.getId() + ") should not have been deallocated but was");
            }
        }

        Set<INDArray> exceptSet = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
        exceptSet.addAll(except);

        // May be null if this thread has no inference session; dependency details are then skipped
        IdentityDependencyTracker<INDArray, InferenceSession.Dep> tracker = currentArrayUseTracker();
        int notReleased = 0;
        for (Map.Entry<INDArray, Boolean> entry : released.entrySet()) {
            INDArray arr = entry.getKey();
            if (entry.getValue() || exceptSet.contains(arr)) {
                continue;
            }
            notReleased++;
            log.info("Not released: array id {}", arr.getId());
            logUnsatisfiedDependencies(tracker, arr);
        }

        if (notReleased > 0) {
            log.warn("SameDiff summary:\n{}", sd.summary());
            throw new IllegalStateException(notReleased + " arrays were not released but should have been");
        }
    }

    /**
     * Returns an identity-based set containing {@link SDVariable#getArr()} for every VARIABLE,
     * CONSTANT and PLACEHOLDER variable in the graph.
     */
    protected Set<INDArray> identitySetAllConstPhVar() {
        Set<INDArray> result = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
        for (SDVariable v : sd.variables()) {
            VariableType type = v.getVariableType();
            if (type == VariableType.VARIABLE || type == VariableType.CONSTANT || type == VariableType.PLACEHOLDER) {
                result.add(v.getArr());
            }
        }
        return result;
    }

    /**
     * Returns the array-use tracker of the inference session registered for the current thread.
     * Returns {@code null} if this thread has no session.
     */
    private IdentityDependencyTracker<INDArray, InferenceSession.Dep> currentArrayUseTracker() {
        InferenceSession session = sd.getSessions().get(Long.valueOf(Thread.currentThread().getId()));
        return session == null ? null : session.getArrayUseTracker();
    }

    /** Logs every dependency (and or-dependency) of {@code arr} with its satisfaction state. */
    private void logDoubleReleaseDiagnostics(INDArray arr) {
        IdentityDependencyTracker<INDArray, InferenceSession.Dep> tracker = currentArrayUseTracker();
        if (tracker == null) {
            log.warn("Array id={} released twice; no inference session for thread {}, cannot report dependencies",
                    arr.getId(), Thread.currentThread().getId());
            return;
        }
        DependencyList<INDArray, InferenceSession.Dep> deps = tracker.getDependencies(arr);
        log.warn("Array id={} released twice; dependencies: {}", arr.getId(), deps);
        if (deps == null) {
            return;
        }
        if (deps.getDependencies() != null) {
            for (InferenceSession.Dep dep : deps.getDependencies()) {
                log.warn("  {}: satisfied={}", dep, tracker.isSatisfied(dep));
            }
        }
        if (deps.getOrDependencies() != null) {
            for (Pair<InferenceSession.Dep, InferenceSession.Dep> pair : deps.getOrDependencies()) {
                log.warn("  {} - ({}, {})", pair, tracker.isSatisfied(pair.getFirst()), tracker.isSatisfied(pair.getSecond()));
            }
        }
    }

    /**
     * Logs the dependencies of {@code arr} that are not yet satisfied. An or-dependency is logged
     * only when neither side is satisfied. Does nothing if {@code tracker} is null.
     */
    private void logUnsatisfiedDependencies(IdentityDependencyTracker<INDArray, InferenceSession.Dep> tracker, INDArray arr) {
        if (tracker == null) {
            return;
        }
        DependencyList<INDArray, InferenceSession.Dep> deps = tracker.getDependencies(arr);
        if (deps == null) {
            return;
        }
        List<InferenceSession.Dep> plain = deps.getDependencies();
        if (plain != null) {
            for (InferenceSession.Dep dep : plain) {
                if (!tracker.isSatisfied(dep)) {
                    log.info("  Not satisfied: {}", dep);
                }
            }
        }
        List<Pair<InferenceSession.Dep, InferenceSession.Dep>> ors = deps.getOrDependencies();
        if (ors != null) {
            for (Pair<InferenceSession.Dep, InferenceSession.Dep> pair : ors) {
                if (!tracker.isSatisfied(pair.getFirst()) && !tracker.isSatisfied(pair.getSecond())) {
                    log.info("   Not satisfied: {}", pair);
                }
            }
        }
    }
}
