package org.nd4j.autodiff.samediff.array;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.DeviceLocalNDArray;

/* loaded from: input_file:org/nd4j/autodiff/samediff/array/ThreadSafeArrayHolder.class */
/**
 * An {@link ArrayHolder} that keeps each named array wrapped in a {@link DeviceLocalNDArray},
 * stored in a {@link ConcurrentHashMap} keyed by array name.
 *
 * <p>Individual map operations are those of {@link ConcurrentHashMap}. Compound operations
 * ({@link #rename(String, String)}, {@link #initFrom(ArrayHolder)}) consist of several map
 * operations and are not atomic as a whole.
 */
public class ThreadSafeArrayHolder implements ArrayHolder {
    /** Array name to its holder. ConcurrentHashMap permits neither null keys nor null values. */
    private final Map<String, DeviceLocalNDArray> map = new ConcurrentHashMap<>();

    /** Forwarded unchanged as the second constructor argument of every DeviceLocalNDArray created here. */
    private final boolean lazyInit;

    /**
     * Creates an empty holder.
     *
     * @param z the "lazy init" flag handed to each {@link DeviceLocalNDArray} this holder creates
     *          (NOTE(review): exact semantics are defined by DeviceLocalNDArray, not visible here)
     */
    public ThreadSafeArrayHolder(boolean z) {
        this.lazyInit = z;
    }

    /**
     * Reports whether an array is currently stored under the given name.
     *
     * @param str the array name; must not be null
     * @return {@code true} if the name is present in this holder
     * @throws NullPointerException if {@code str} is null
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public boolean hasArray(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        return this.map.containsKey(str);
    }

    /**
     * Returns the array stored under the given name, as provided by its
     * {@link DeviceLocalNDArray#get()}.
     *
     * @param str the array name; must not be null
     * @return the array obtained from the stored holder
     * @throws NullPointerException if {@code str} is null, or if no array is stored under it
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public INDArray getArray(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        DeviceLocalNDArray holder = this.map.get(str);
        if (holder == null) {
            // Same exception type the unchecked dereference used to produce, but with a usable message.
            throw new NullPointerException("No array is stored under the name \"" + str + "\"");
        }
        return holder.get();
    }

    /**
     * Stores an array under the given name. If the name is already present, the existing
     * {@link DeviceLocalNDArray} is updated in place; otherwise a new one is created.
     * An array that reports {@code isView()} is duplicated first, so the view itself is never stored.
     *
     * @param str      the array name; must not be null
     * @param iNDArray the array to store; must not be null
     * @throws NullPointerException if either argument is null
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public void setArray(@NonNull String str, @NonNull INDArray iNDArray) {
        if (str == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (iNDArray == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        INDArray toStore = iNDArray.isView() ? iNDArray.dup() : iNDArray;

        DeviceLocalNDArray existing = this.map.get(str);
        if (existing == null) {
            // putIfAbsent makes the insert atomic: if another thread registered this name in the
            // meantime, its holder is returned and updated below instead of being overwritten.
            existing = this.map.putIfAbsent(str, new DeviceLocalNDArray(toStore, this.lazyInit));
            if (existing == null) {
                return;
            }
        }
        existing.update(toStore);
    }

    /**
     * Removes the array stored under the given name.
     *
     * @param str the array name; must not be null
     * @return the removed array (from {@link DeviceLocalNDArray#get()}), or {@code null} if the
     *         name was not present
     * @throws NullPointerException if {@code str} is null
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public INDArray removeArray(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        DeviceLocalNDArray removed = this.map.remove(str);
        return removed == null ? null : removed.get();
    }

    /**
     * @return the number of arrays currently stored
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public int size() {
        return this.map.size();
    }

    /**
     * Discards the current contents and copies every array from another holder via
     * {@link #setArray(String, INDArray)}. Passing this holder itself leaves it unchanged.
     *
     * @param arrayHolder the holder to copy from
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public void initFrom(ArrayHolder arrayHolder) {
        if (arrayHolder == this) {
            // Clearing first would wipe the very entries we are about to copy.
            return;
        }
        // Touch the source before clearing so a null/failing source does not cost us our contents.
        Collection<String> names = arrayHolder.arrayNames();
        this.map.clear();
        for (String name : names) {
            setArray(name, arrayHolder.getArray(name));
        }
    }

    /**
     * @return an unmodifiable wrapper around the map's key set; it is a view of the underlying
     *         map, not a snapshot copy
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public Collection<String> arrayNames() {
        return Collections.unmodifiableCollection(this.map.keySet());
    }

    /**
     * Moves the array stored under {@code str} to the name {@code str2}, replacing whatever was
     * stored under {@code str2}. The entry is removed and then re-inserted, so between those two
     * steps it is present under neither name.
     *
     * @param str  the current name; must not be null and must be present
     * @param str2 the new name; must not be null
     * @throws NullPointerException if either argument is null, or if nothing is stored under
     *                              {@code str} (the holder is left unmodified in that case)
     */
    @Override // org.nd4j.autodiff.samediff.ArrayHolder
    public void rename(@NonNull String str, @NonNull String str2) {
        if (str == null) {
            throw new NullPointerException("from is marked non-null but is null");
        }
        if (str2 == null) {
            throw new NullPointerException("to is marked non-null but is null");
        }
        DeviceLocalNDArray moved = this.map.remove(str);
        if (moved == null) {
            // ConcurrentHashMap.put would reject the null value with a message-less NPE;
            // keep the exception type but say what actually went wrong.
            throw new NullPointerException(
                    "Cannot rename \"" + str + "\" to \"" + str2 + "\": no array is stored under the name \"" + str + "\"");
        }
        this.map.put(str2, moved);
    }
}
