package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.recurrent.RNN;
import ai.djl.pytorch.jni.JniUtils;
import java.util.Iterator;
import java.util.List;

/**
 * PyTorch engine implementation of {@link NDArrayEx} for a single {@link PtNDArray}.
 *
 * <p>Supported operations delegate to native calls in {@link JniUtils}. Operations that this
 * engine does not provide throw {@link UnsupportedOperationException}.
 */
public class PtNDArrayEx implements NDArrayEx {

    /** The array every operation of this instance acts on. */
    private final PtNDArray array;

    /**
     * Creates the extension for the given array.
     *
     * @param ptNDArray the array that all operations of this instance act on
     */
    public PtNDArrayEx(PtNDArray ptNDArray) {
        this.array = ptNDArray;
    }

    /**
     * Computes {@code number / array} element-wise.
     *
     * @param number the scalar dividend
     * @return a new array holding the quotient
     */
    @Override
    public PtNDArray rdiv(Number number) {
        // The scalar operand is only an intermediate: close it once the result exists instead of
        // leaving it alive until the owning manager is closed.
        try (NDArray scalar = array.getManager().create(number)) {
            return rdiv(scalar);
        }
    }

    /**
     * Computes {@code nDArray / array} element-wise.
     *
     * @param nDArray the dividend
     * @return a new array holding the quotient
     */
    @Override
    public PtNDArray rdiv(NDArray nDArray) {
        return (PtNDArray) nDArray.div(array);
    }

    @Override
    public PtNDArray rdivi(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rdivi(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rsub(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rsub(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rsubi(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rsubi(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rmod(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rmod(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rmodi(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rmodi(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rpow(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public PtNDArray rpowi(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    // ---- Activations: thin delegations to the matching native JniUtils call. ----

    @Override
    public PtNDArray relu() {
        return JniUtils.relu(array);
    }

    @Override
    public PtNDArray sigmoid() {
        return JniUtils.sigmoid(array);
    }

    @Override
    public PtNDArray tanh() {
        return JniUtils.tanh(array);
    }

    @Override
    public PtNDArray softPlus() {
        return JniUtils.softPlus(array);
    }

    @Override
    public PtNDArray softSign() {
        return JniUtils.softSign(array);
    }

    @Override
    public PtNDArray leakyRelu(float f) {
        return JniUtils.leakyRelu(array, f);
    }

    @Override
    public PtNDArray elu(float f) {
        return JniUtils.elu(array, f);
    }

    @Override
    public PtNDArray selu() {
        return JniUtils.selu(array);
    }

    @Override
    public PtNDArray gelu() {
        return JniUtils.gelu(array);
    }

    // ---- Pooling ----

    @Override
    public PtNDArray maxPool(Shape shape, Shape shape2, Shape shape3, boolean z) {
        return JniUtils.maxPool(array, shape, shape2, shape3, z);
    }

    /**
     * Max-pools over all axes after the first two, returning an array shaped like the first two
     * axes of the input.
     *
     * @return the pooled array
     * @throws IllegalArgumentException if the input rank is not in [3, 5]
     */
    @Override
    public PtNDArray globalMaxPool() {
        return reshapeGlobalPool(JniUtils.adaptiveMaxPool(array, getPoolShape(array)));
    }

    @Override
    public PtNDArray avgPool(Shape shape, Shape shape2, Shape shape3, boolean z, boolean z2) {
        return JniUtils.avgPool(array, shape, shape2, shape3, z, z2);
    }

    /**
     * Average-pools over all axes after the first two, returning an array shaped like the first
     * two axes of the input.
     *
     * @return the pooled array
     * @throws IllegalArgumentException if the input rank is not in [3, 5]
     */
    @Override
    public PtNDArray globalAvgPool() {
        return reshapeGlobalPool(JniUtils.adaptiveAvgPool(array, getPoolShape(array)));
    }

    /**
     * Applies Lp pooling. Padding is not supported by this engine.
     *
     * @param normType the p of the Lp norm
     * @param kernelShape the pooling window
     * @param stride the window stride
     * @param padding must contain only zeros
     * @param ceilMode whether to use ceil instead of floor when computing the output shape
     * @return the pooled array
     * @throws IllegalArgumentException if any padding component is non-zero
     */
    @Override
    public PtNDArray lpPool(
            float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        // Checking Shape.size() is not enough: it is the product of the dimensions, so a padding
        // such as (1, 0) would report 0. Inspect every component instead.
        for (int i = 0; i < padding.dimension(); i++) {
            if (padding.get(i) != 0) {
                throw new IllegalArgumentException("padding is not supported for PyTorch engine");
            }
        }
        return JniUtils.lpPool(array, normType, kernelShape, stride, ceilMode);
    }

    /**
     * Lp-pools over all axes after the first two, using the whole trailing extent as the kernel.
     *
     * @param f the p of the Lp norm
     * @return an array shaped like the first two axes of the input
     * @throws IllegalArgumentException if the input rank is not in [3, 5]
     */
    @Override
    public PtNDArray globalLpPool(float f) {
        return reshapeGlobalPool(
                JniUtils.lpPool(array, f, array.getShape().slice(2), getPoolShape(array), false));
    }

    /**
     * Reshapes a globally pooled result to the first two axes of the input, then closes the
     * intermediate pooled array.
     */
    private PtNDArray reshapeGlobalPool(PtNDArray pooled) {
        try (PtNDArray intermediate = pooled) {
            return (PtNDArray) intermediate.reshape(array.getShape().slice(0, 2));
        }
    }

    // ---- Optimizer updates ----

    @Override
    public void adadeltaUpdate(
            NDList nDList, NDList nDList2, float f, float f2, float f3, float f4, float f5) {
        throw new UnsupportedOperationException(
                "AdaDelta optimzier is not supported for PyTorch engine!");
    }

    @Override
    public void adagradUpdate(
            NDList nDList, NDList nDList2, float f, float f2, float f3, float f4, float f5) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Runs the native Adam update on {@code nDList.get(0..3)}, then zeroes the gradient of the
     * single array in {@code nDList2}. The boolean {@code z} is not forwarded to the native call.
     */
    @Override
    public void adamUpdate(
            NDList nDList,
            NDList nDList2,
            float f,
            float f2,
            float f3,
            float f4,
            float f5,
            float f6,
            float f7,
            float f8,
            boolean z,
            boolean z2) {
        PtNDManager manager = array.getManager();
        JniUtils.adamUpdate(
                manager.from(nDList.get(0)),
                manager.from(nDList.get(1)),
                manager.from(nDList.get(2)),
                manager.from(nDList.get(3)),
                f,
                f2,
                f3,
                f4,
                f5,
                f6,
                f7,
                f8,
                z2);
        JniUtils.zeroGrad(manager.from(nDList2.singletonOrThrow()));
    }

    @Override
    public void nagUpdate(
            NDList nDList, NDList nDList2, float f, float f2, float f3, float f4, float f5) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void rmspropUpdate(
            NDList nDList,
            NDList nDList2,
            float f,
            float f2,
            float f3,
            float f4,
            float f5,
            float f6,
            float f7,
            boolean z) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Runs the native SGD update, then zeroes the gradient of the single array in
     * {@code nDList2}. The third input ({@code nDList.get(2)}) is read only when {@code f5} is
     * non-zero; otherwise {@code null} is passed in its place.
     */
    @Override
    public void sgdUpdate(
            NDList nDList,
            NDList nDList2,
            float f,
            float f2,
            float f3,
            float f4,
            float f5,
            boolean z) {
        PtNDManager manager = array.getManager();
        JniUtils.sgdUpdate(
                manager.from(nDList.get(0)),
                manager.from(nDList.get(1)),
                f5 == 0.0f ? null : manager.from(nDList.get(2)),
                f,
                f2,
                f3,
                f4,
                f5);
        JniUtils.zeroGrad(manager.from(nDList2.singletonOrThrow()));
    }

    // ---- Neural-network layers ----

    @Override
    public NDList convolution(
            NDArray nDArray,
            NDArray nDArray2,
            NDArray nDArray3,
            Shape shape,
            Shape shape2,
            Shape shape3,
            int i) {
        PtNDManager manager = array.getManager();
        return new NDList(
                JniUtils.convolution(
                        manager.from(nDArray),
                        manager.from(nDArray2),
                        manager.from(nDArray3),
                        shape,
                        shape2,
                        shape3,
                        i));
    }

    @Override
    public NDList deconvolution(
            NDArray nDArray,
            NDArray nDArray2,
            NDArray nDArray3,
            Shape shape,
            Shape shape2,
            Shape shape3,
            Shape shape4,
            int i) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public NDList linear(NDArray nDArray, NDArray nDArray2, NDArray nDArray3) {
        PtNDManager manager = array.getManager();
        return new NDList(
                JniUtils.linear(
                        manager.from(nDArray), manager.from(nDArray2), manager.from(nDArray3)));
    }

    /**
     * Embedding lookup. Only {@link SparseFormat#DENSE} and {@link SparseFormat#COO} are
     * accepted; COO is forwarded to the native call as its boolean flag.
     */
    @Override
    public NDList embedding(NDArray nDArray, NDArray nDArray2, SparseFormat sparseFormat) {
        if (!sparseFormat.equals(SparseFormat.DENSE) && !sparseFormat.equals(SparseFormat.COO)) {
            throw new IllegalArgumentException("PyTorch only supports COO");
        }
        PtNDManager manager = array.getManager();
        return new NDList(
                JniUtils.embedding(
                        manager.from(nDArray),
                        manager.from(nDArray2),
                        sparseFormat.equals(SparseFormat.COO)));
    }

    @Override
    public NDList prelu(NDArray nDArray, NDArray nDArray2) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public NDList dropout(NDArray nDArray, float f, boolean z) {
        return new NDList(JniUtils.dropout(array.getManager().from(nDArray), f, z));
    }

    @Override
    public NDList layerNorm(
            NDArray nDArray, Shape shape, NDArray nDArray2, NDArray nDArray3, float f) {
        PtNDManager manager = array.getManager();
        return new NDList(
                JniUtils.layerNorm(
                        manager.from(nDArray),
                        shape,
                        manager.from(nDArray2),
                        manager.from(nDArray3),
                        f));
    }

    /**
     * Batch normalization along {@code axis}.
     *
     * <p>The native call normalizes axis 1, so for any other axis the input is swapped to axis 1
     * and the result swapped back. Intermediates live in a temporary sub-manager that is closed
     * afterwards. The caller's {@code input} is always returned to its original manager, even on
     * failure, so closing the sub-manager never closes it.
     *
     * @param input the array to normalize
     * @param runningMean passed through to the native call
     * @param runningVar passed through to the native call
     * @param gamma passed through to the native call
     * @param beta passed through to the native call
     * @param axis the axis to normalize; {@code -1} skips the swap
     * @param momentum forwarded to the native call as {@code 1 - momentum}
     * @param eps passed through to the native call
     * @param training passed through to the native call
     * @return a single-element list holding the normalized array
     */
    @Override
    public NDList batchNorm(
            NDArray input,
            NDArray runningMean,
            NDArray runningVar,
            NDArray gamma,
            NDArray beta,
            int axis,
            float momentum,
            float eps,
            boolean training) {
        PtNDManager manager = array.getManager();
        if (axis == -1) {
            return new NDList(
                    JniUtils.batchNorm(
                            manager.from(input),
                            manager.from(runningMean),
                            manager.from(runningVar),
                            manager.from(gamma),
                            manager.from(beta),
                            training,
                            1.0f - momentum,
                            eps));
        }
        try (NDManager subManager = input.getManager().newSubManager()) {
            NDManager parentManager = subManager.getParentManager();
            input.attach(subManager);
            try {
                NDArray result =
                        JniUtils.batchNorm(
                                        manager.from(input.swapAxes(1, axis)),
                                        manager.from(runningMean),
                                        manager.from(runningVar),
                                        manager.from(gamma),
                                        manager.from(beta),
                                        training,
                                        1.0f - momentum,
                                        eps)
                                .swapAxes(1, axis);
                result.attach(parentManager);
                return new NDList(result);
            } finally {
                // Runs before the sub-manager closes, so the caller's input survives failures.
                input.attach(parentManager);
            }
        }
    }

    @Override
    public NDList rnn(
            NDArray nDArray,
            NDArray nDArray2,
            NDList nDList,
            boolean z,
            int i,
            RNN.Activation activation,
            double d,
            boolean z2,
            boolean z3,
            boolean z4) {
        PtNDManager manager = array.getManager();
        return JniUtils.rnn(
                manager.from(nDArray),
                manager.from(nDArray2),
                nDList,
                z,
                i,
                activation,
                d,
                z2,
                z3,
                z4);
    }

    @Override
    public NDList gru(
            NDArray nDArray,
            NDArray nDArray2,
            NDList nDList,
            boolean z,
            int i,
            double d,
            boolean z2,
            boolean z3,
            boolean z4) {
        PtNDManager manager = array.getManager();
        return JniUtils.gru(
                manager.from(nDArray), manager.from(nDArray2), nDList, z, i, d, z2, z3, z4);
    }

    @Override
    public NDList lstm(
            NDArray nDArray,
            NDList nDList,
            NDList nDList2,
            boolean z,
            int i,
            double d,
            boolean z2,
            boolean z3,
            boolean z4) {
        return JniUtils.lstm(array.getManager().from(nDArray), nDList, nDList2, z, i, d, z2, z3, z4);
    }

    // ---- Image operations ----

    /**
     * Resizes a 3-D or 4-D image array whose last axis is treated as channels.
     *
     * <p>The array is converted to float32 if needed, moved to channel-second layout for the
     * native interpolation, and moved back afterwards. A 3-D input gets a leading batch axis
     * added and removed again. Intermediates are freed through a temporary sub-manager. This
     * array is always re-attached to its original manager, including on failure, so a failed
     * resize never closes it.
     *
     * @param width the target size of the second-to-last axis
     * @param height the target size of the third-to-last axis
     * @param interpolation interpolation code in [0, 3]
     * @return the resized array, attached to this array's manager
     * @throws IllegalArgumentException if this array is empty
     * @throws UnsupportedOperationException if {@code interpolation} is not supported
     */
    @Override
    public PtNDArray resize(int width, int height, int interpolation) {
        // Validate before changing ownership; otherwise the sub-manager would close this array.
        if (array.isEmpty()) {
            throw new IllegalArgumentException("attempt to resize of an empty NDArray");
        }
        try (NDManager subManager = array.getManager().newSubManager()) {
            NDManager parentManager = subManager.getParentManager();
            array.attach(subManager);
            try {
                PtNDArray image = array;
                if (image.getDataType() != DataType.FLOAT32) {
                    image = image.toType(DataType.FLOAT32, true);
                }
                int dimension = image.getShape().dimension();
                if (dimension == 3) {
                    image = image.expandDims(0);
                }
                PtNDArray resized =
                        JniUtils.interpolate(
                                        array.getManager().from(image.transpose(0, 3, 1, 2)),
                                        new long[] {height, width},
                                        getInterpolationMode(interpolation),
                                        false)
                                .transpose(0, 2, 3, 1);
                if (dimension == 3) {
                    resized = resized.squeeze(0);
                }
                resized.attach(parentManager);
                return resized;
            } finally {
                // Runs before the sub-manager closes, so this array survives any failure above.
                array.attach(parentManager);
            }
        }
    }

    @Override
    public NDArray randomFlipLeftRight() {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public NDArray randomFlipTopBottom() {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public NDArray randomBrightness(float f) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public NDArray randomHue(float f) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public NDArray randomColorJitter(float f, float f2, float f3, float f4) {
        throw new UnsupportedOperationException("Not implemented");
    }

    // ---- Miscellaneous ----

    /**
     * Returns a PyTorch indexer bound to {@code nDManager}.
     *
     * @param nDManager must be a {@link PtNDManager}
     * @return the indexer
     */
    @Override
    public NDArrayIndexer getIndexer(NDManager nDManager) {
        return new PtNDArrayIndexer((PtNDManager) nDManager);
    }

    /**
     * Delegates to {@code JniUtils.where(condition, this, other)}. Presumably this picks from
     * this array where the condition holds and from {@code nDArray2} elsewhere; confirm against
     * JniUtils.
     *
     * @param nDArray the condition; must have exactly this array's shape
     * @param nDArray2 the alternative values
     * @return the selected array
     * @throws UnsupportedOperationException if the condition shape differs from this array's
     */
    @Override
    public PtNDArray where(NDArray nDArray, NDArray nDArray2) {
        if (!nDArray.getShape().equals(array.getShape())) {
            throw new UnsupportedOperationException(
                    "condition and self shape mismatch, broadcast is not supported");
        }
        PtNDManager manager = array.getManager();
        return JniUtils.where(manager.from(nDArray), array, manager.from(nDArray2));
    }

    /**
     * Stacks this array followed by every array in {@code nDList} along axis {@code i}.
     *
     * @param nDList the arrays to stack after this one
     * @param i the new axis
     * @return the stacked array
     */
    @Override
    public PtNDArray stack(NDList nDList, int i) {
        return JniUtils.stack(withThisFirst(nDList), i);
    }

    /**
     * Concatenates this array followed by every array in {@code nDList} along axis {@code i},
     * after validating the list with {@link NDUtils#checkConcatInput}.
     *
     * @param nDList the arrays to concatenate after this one
     * @param i the concatenation axis
     * @return the concatenated array
     */
    @Override
    public PtNDArray concat(NDList nDList, int i) {
        NDUtils.checkConcatInput(nDList);
        return JniUtils.cat(withThisFirst(nDList), i);
    }

    /** Builds an array holding this array followed by {@code others}, each as a PtNDArray. */
    private PtNDArray[] withThisFirst(NDList others) {
        PtNDManager manager = array.getManager();
        PtNDArray[] all = new PtNDArray[others.size() + 1];
        all[0] = array;
        int next = 1;
        for (NDArray other : others) {
            all[next++] = manager.from(other);
        }
        return all;
    }

    @Override
    public NDList multiBoxTarget(NDList nDList, float f, float f2, float f3, float f4, int i) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public NDList multiBoxPrior(
            List<Float> list,
            List<Float> list2,
            List<Float> list3,
            List<Float> list4,
            boolean z) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public NDList multiBoxDetection(
            NDList nDList, boolean z, float f, int i, float f2, boolean z2, int i2) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Returns the array this extension operates on. */
    @Override
    public PtNDArray getArray() {
        return array;
    }

    /**
     * Returns an all-ones shape with one entry per axis after the first two. It is used as the
     * adaptive-pool output size and as the stride for global Lp pooling.
     *
     * @throws IllegalArgumentException if the rank of {@code nDArray} is not in [3, 5]
     */
    private Shape getPoolShape(NDArray nDArray) {
        switch (nDArray.getShape().dimension() - 2) {
            case 1:
                return new Shape(1);
            case 2:
                return new Shape(1, 1);
            case 3:
                return new Shape(1, 1, 1);
            default:
                throw new IllegalArgumentException("the input dimension should be in [3, 5]");
        }
    }

    /**
     * Maps the caller's interpolation code (0..3) to the mode index passed to
     * {@code JniUtils.interpolate}. NOTE(review): the native meaning of the mode indices is
     * defined on the JNI side; confirm there.
     *
     * @throws UnsupportedOperationException for codes outside [0, 3]
     */
    private int getInterpolationMode(int i) {
        switch (i) {
            case 0:
                return 0;
            case 1:
                return 2;
            case 2:
                return 5;
            case 3:
                return 3;
            default:
                throw new UnsupportedOperationException(
                        "The kind of interpolation is not supported.");
        }
    }
}
