package ai.djl.translate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.StringUtils;

/* loaded from: input_file:ai/djl/translate/PaddingStackBatchifier.class */
/**
 * A {@link Batchifier} that pads selected arrays along a chosen dimension so that inputs whose
 * arrays differ in size can be stacked into one batch.
 *
 * <p>Each padding rule, added with {@link Builder#addPad}, names:
 *
 * <ul>
 *   <li>the index of the array within every input {@link NDList};
 *   <li>the dimension of that (un-batched) array to pad;
 *   <li>a supplier of the padding value;
 *   <li>an optional fixed padded size.
 * </ul>
 *
 * <p>When {@link Builder#optIncludeValidLengths(boolean)} is enabled, one extra array per rule is
 * appended to the batch. It holds each sample's size along the padded dimension before padding.
 * {@link #unbatchify} uses these lengths to strip the padding again, and {@link #split} slices
 * them alongside the data.
 */
public final class PaddingStackBatchifier implements Batchifier {

    private static final long serialVersionUID = 1;

    /** Index, within each input {@link NDList}, of the array each rule pads. */
    private final List<Integer> arraysToPad;
    /** Dimension of the un-batched array that each rule pads. */
    private final List<Integer> dimsToPad;
    // NOTE(review): transient, so after Java deserialization this field is null and batchify()
    // would throw a NullPointerException; presumably the suppliers are not Serializable.
    private final transient List<NDArraySupplier> paddingSuppliers;
    /** Fixed padded size for each rule, or -1 to pad to the largest size found in the batch. */
    private final List<Integer> paddingSizes;
    /** Whether valid (pre-padding) lengths are appended to the batched output. */
    private final boolean includeValidLengths;

    /** Builder for {@link PaddingStackBatchifier}. */
    public static final class Builder {

        private final List<Integer> arraysToPad = new ArrayList<>();
        private final List<Integer> dimsToPad = new ArrayList<>();
        private final List<NDArraySupplier> paddingSuppliers = new ArrayList<>();
        private final List<Integer> paddingSizes = new ArrayList<>();
        private boolean includeValidLengths;

        private Builder() {}

        /**
         * Sets whether the valid lengths of the padded dimensions are appended to the batch.
         *
         * @param includeValidLengths {@code true} to append one valid-length array per rule
         * @return this builder
         */
        public Builder optIncludeValidLengths(boolean includeValidLengths) {
            this.includeValidLengths = includeValidLengths;
            return this;
        }

        /**
         * Adds a padding rule that pads to the largest size found in each batch.
         *
         * @param array index of the array within each input {@link NDList}
         * @param dim dimension of that (un-batched) array to pad
         * @param supplier supplier of the padding value
         * @return this builder
         */
        public Builder addPad(int array, int dim, NDArraySupplier supplier) {
            return addPad(array, dim, supplier, -1);
        }

        /**
         * Adds a padding rule.
         *
         * @param array index of the array within each input {@link NDList}
         * @param dim dimension of that (un-batched) array to pad
         * @param supplier supplier of the padding value
         * @param paddingSize fixed size to pad to, or -1 to pad to the largest size in the batch
         * @return this builder
         */
        public Builder addPad(int array, int dim, NDArraySupplier supplier, int paddingSize) {
            arraysToPad.add(array);
            dimsToPad.add(dim);
            paddingSuppliers.add(supplier);
            paddingSizes.add(paddingSize);
            return this;
        }

        /**
         * Builds a {@link PaddingStackBatchifier} from the rules added so far.
         *
         * @return a new batchifier
         */
        public PaddingStackBatchifier build() {
            return new PaddingStackBatchifier(this);
        }
    }

    private PaddingStackBatchifier(Builder builder) {
        // Copy the lists so that reusing the builder after build() cannot alter this instance.
        arraysToPad = new ArrayList<>(builder.arraysToPad);
        dimsToPad = new ArrayList<>(builder.dimsToPad);
        paddingSuppliers = new ArrayList<>(builder.paddingSuppliers);
        paddingSizes = new ArrayList<>(builder.paddingSizes);
        includeValidLengths = builder.includeValidLengths;
    }

    /**
     * Pads the configured arrays of every input in place, then stacks the inputs.
     *
     * <p>Requires at least one input whose first array exists. The manager of that array is used
     * to create the padding and valid-length arrays.
     *
     * @param inputs the per-sample inputs; the padded arrays are replaced in these lists
     * @return the stacked batch, followed by one valid-length array per rule when enabled
     * @throws IllegalArgumentException if a fixed padding size is smaller than the largest size
     *     found in the batch, or a padding value has higher rank than the array it pads
     */
    @Override // ai.djl.translate.Batchifier
    public NDList batchify(NDList[] inputs) {
        NDList validLengths = new NDList(arraysToPad.size());
        NDManager manager = inputs[0].get(0).getManager();
        for (int i = 0; i < arraysToPad.size(); i++) {
            int arrayIndex = arraysToPad.get(i);
            int dim = dimsToPad.get(i);
            NDArray padding = paddingSuppliers.get(i).get(manager);
            long paddingSize = paddingSizes.get(i);
            long maxSize = findMaxSize(inputs, arrayIndex, dim);
            if (paddingSize != -1 && maxSize > paddingSize) {
                throw new IllegalArgumentException(
                        "The batchifier padding size is too small "
                                + maxSize
                                + StringUtils.SPACE
                                + paddingSize);
            }
            long[] lengths =
                    padArrays(inputs, arrayIndex, dim, padding, Math.max(maxSize, paddingSize));
            if (includeValidLengths) {
                // Only allocate the valid-length arrays when they are actually returned.
                validLengths.add(manager.create(lengths));
            }
        }
        NDList result = Batchifier.STACK.batchify(inputs);
        if (includeValidLengths) {
            result.addAll(validLengths);
        }
        return result;
    }

    /**
     * Splits a batch into per-sample lists, removing the padding when valid lengths are present.
     *
     * <p>When valid lengths are included, the last {@code arraysToPad.size()} arrays of {@code
     * inputs} are taken to be the valid-length arrays produced by {@link #batchify}.
     *
     * @param inputs the batched data
     * @return one {@link NDList} per sample
     */
    @Override // ai.djl.translate.Batchifier
    public NDList[] unbatchify(NDList inputs) {
        if (!includeValidLengths) {
            return Batchifier.STACK.unbatchify(inputs);
        }
        int dataSize = inputs.size() - arraysToPad.size();
        NDList validLengths = new NDList(inputs.subList(dataSize, inputs.size()));
        NDList[] samples = Batchifier.STACK.unbatchify(new NDList(inputs.subList(0, dataSize)));
        for (int i = 0; i < samples.length; i++) {
            NDList sample = samples[i];
            for (int j = 0; j < arraysToPad.size(); j++) {
                long validLength = validLengths.get(j).getLong(i);
                int arrayIndex = arraysToPad.get(j);
                NDArray padded = sample.get(arrayIndex);
                // STACK.unbatchify has already removed the batch axis, and dimsToPad is expressed
                // in un-batched coordinates (as in batchify), so it is used directly as the axis.
                NDArray unpadded =
                        padded.get(NDIndex.sliceAxis(dimsToPad.get(j), 0L, validLength));
                // Keep the array name, matching what padArrays does when padding.
                unpadded.setName(padded.getName());
                sample.set(arrayIndex, unpadded);
            }
        }
        return samples;
    }

    /**
     * Splits a batch into slices, slicing the valid-length arrays (if present) to match.
     *
     * @param list the batched data, followed by the valid-length arrays when enabled
     * @param numOfSlices the number of slices to produce
     * @param evenSplit whether the batch must divide evenly into the slices
     * @return the slices
     */
    @Override // ai.djl.translate.Batchifier
    public NDList[] split(NDList list, int numOfSlices, boolean evenSplit) {
        if (!includeValidLengths) {
            return Batchifier.STACK.split(list, numOfSlices, evenSplit);
        }
        int dataSize = list.size() - arraysToPad.size();
        NDList validLengths = new NDList(list.subList(dataSize, list.size()));
        NDList data = new NDList(list.subList(0, dataSize));
        NDList[] slices = Batchifier.STACK.split(data, numOfSlices, evenSplit);
        // Assumes every slice except possibly the last has the first slice's batch size.
        long sliceSize = slices[0].get(0).getShape().get(0);
        long totalSize = data.get(0).getShape().get(0);
        for (int i = 0; i < slices.length; i++) {
            long start = i * sliceSize;
            long end = Math.min(start + sliceSize, totalSize);
            for (int j = 0; j < arraysToPad.size(); j++) {
                slices[i].add(validLengths.get(j).get(NDIndex.sliceAxis(0, start, end)));
            }
        }
        return slices;
    }

    /**
     * Finds the largest size of dimension {@code dim} of array {@code arrayIndex} across inputs.
     *
     * @param inputs the per-sample inputs
     * @param arrayIndex index of the array within each input
     * @param dim the dimension to measure
     * @return the largest size found, or -1 if {@code inputs} is empty
     */
    public static long findMaxSize(NDList[] inputs, int arrayIndex, int dim) {
        long maxSize = -1;
        for (NDList input : inputs) {
            maxSize = Math.max(maxSize, input.get(arrayIndex).getShape().get(dim));
        }
        return maxSize;
    }

    /**
     * Pads array {@code arrayIndex} of every input along {@code dim} up to {@code maxSize}.
     *
     * <p>The inputs are modified in place: each padded array replaces the original in its
     * {@link NDList}, keeping the original name.
     *
     * <p>The padding is shaped in one of two ways:
     *
     * <ul>
     *   <li>If the padding value has the same rank as the array, it is tiled with {@link
     *       NDArray#repeat(Shape)}.
     *   <li>If it has lower rank, it is broadcast.
     * </ul>
     *
     * <p>In both cases the padding is cast to the array's data type before concatenation.
     *
     * @param inputs the per-sample inputs; modified in place
     * @param arrayIndex index of the array within each input
     * @param dim the dimension to pad
     * @param padding the padding value
     * @param maxSize the size to pad to
     * @return the original size of {@code dim} for each input
     * @throws IllegalArgumentException if {@code padding} has higher rank than an array to pad
     */
    public static long[] padArrays(
            NDList[] inputs, int arrayIndex, int dim, NDArray padding, long maxSize) {
        long[] originalSizes = new long[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            NDArray array = inputs[i].get(arrayIndex);
            String name = array.getName();
            long dimSize = array.getShape().get(dim);
            if (dimSize < maxSize) {
                int arrayRank = array.getShape().dimension();
                int rankDiff = arrayRank - padding.getShape().dimension();
                if (rankDiff < 0) {
                    throw new IllegalArgumentException(
                            "The padding must be <="
                                    + arrayRank
                                    + " dimensions, but found "
                                    + padding.getShape().dimension());
                }
                Shape padShape = Shape.update(array.getShape(), dim, maxSize - dimSize);
                NDArray paddingArray =
                        rankDiff == 0 ? padding.repeat(padShape) : padding.broadcast(padShape);
                array = array.concat(paddingArray.toType(array.getDataType(), false), dim);
            }
            array.setName(name);
            inputs[i].set(arrayIndex, array);
            originalSizes[i] = dimSize;
        }
        return originalSizes;
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
