package ai.djl.pytorch.engine;

import ai.djl.Device;
import ai.djl.ndarray.BaseNDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.util.NativeResource;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/* loaded from: input_file:ai/djl/pytorch/engine/PtNDArray.class */
public class PtNDArray extends NativeResource<Long> implements NDArray {
    // Optional user-facing name; read and written only through getName()/setName().
    private String name;
    // Lazily cached device; populated from JniUtils.getDevice on first getDevice() call.
    private Device device;
    // Lazily cached element type; set eagerly to STRING by the String[] constructor.
    private DataType dataType;
    // Lazily cached shape; set eagerly by the String[] constructor.
    private Shape shape;
    // Lazily cached sparse format; populated from JniUtils.getSparseFormat on first use.
    private SparseFormat sparseFormat;
    // Cached requires-grad flag; null means "not yet queried". Also written by setRequiresGradient().
    private Boolean hasGradient;
    // Current owning manager; reassigned by attach/returnResource/tempAttach/detach.
    private PtNDManager manager;
    // Extension helper created by the handle-based constructors.
    // NOTE(review): left null by the String[] constructor — confirm callers never need it there.
    private PtNDArrayEx ptNDArrayEx;
    // Backing strings for STRING arrays; only assigned by the String[] constructor.
    private String[] strs;
    // Direct buffer whose contents were handed to native code (constructor or set()). It is kept
    // only for non-GPU arrays, presumably so the memory stays reachable while the native tensor
    // may alias it — TODO confirm.
    private ByteBuffer dataRef;

    /**
     * Wraps an existing native tensor handle.
     *
     * @param ptNDManager the manager that will own this array
     * @param j the native tensor handle
     */
    public PtNDArray(PtNDManager ptNDManager, long j) {
        super(Long.valueOf(j));
        this.manager = ptNDManager;
        this.ptNDArrayEx = new PtNDArrayEx(this);
        // Register with the owning manager (keyed by uid) and then with NDScope; both happen
        // after super(...) has stored the handle.
        ptNDManager.attachInternal(getUid(), this);
        NDScope.register(this);
    }

    /**
     * Wraps an existing native tensor handle and keeps a reference to the buffer it was built from.
     *
     * @param ptNDManager the manager that will own this array
     * @param j the native tensor handle
     * @param byteBuffer buffer retained in {@code dataRef}; presumably the memory backing the
     *     tensor, kept reachable for the array's lifetime — TODO confirm with the caller
     */
    public PtNDArray(PtNDManager ptNDManager, long j, ByteBuffer byteBuffer) {
        super(Long.valueOf(j));
        this.manager = ptNDManager;
        this.ptNDArrayEx = new PtNDArrayEx(this);
        ptNDManager.attachInternal(getUid(), this);
        this.dataRef = byteBuffer;
        NDScope.register(this);
    }

    /**
     * Creates a STRING array held entirely on the Java side.
     *
     * <p>The handle is the placeholder {@code -1L}; no native tensor is created here. NOTE(review):
     * unlike the other constructors this one does not call {@code attachInternal} and leaves
     * {@code ptNDArrayEx} null — confirm that is intended.
     *
     * @param ptNDManager the manager recorded as owner
     * @param strArr the string contents (stored without copying)
     * @param shape the logical shape of the string array
     */
    public PtNDArray(PtNDManager ptNDManager, String[] strArr, Shape shape) {
        super(-1L);
        this.manager = ptNDManager;
        this.strs = strArr;
        this.shape = shape;
        this.dataType = DataType.STRING;
        NDScope.register(this);
    }

    /** Returns the manager that currently owns this array. */
    @Override // ai.djl.ndarray.NDResource
    public PtNDManager getManager() {
        return manager;
    }

    /** Returns the name assigned through {@link #setName(String)}, or {@code null} if unset. */
    @Override // ai.djl.ndarray.NDArray
    public String getName() {
        return name;
    }

    /** Assigns a user-facing name to this array. */
    @Override // ai.djl.ndarray.NDArray
    public void setName(String str) {
        name = str;
    }

    /** Returns the element type, querying the native tensor once and caching the answer. */
    @Override // ai.djl.ndarray.NDArray
    public DataType getDataType() {
        DataType cached = this.dataType;
        if (cached == null) {
            cached = JniUtils.getDataType(this);
            this.dataType = cached;
        }
        return cached;
    }

    /** Returns the device holding the tensor, querying it once and caching the answer. */
    @Override // ai.djl.ndarray.NDArray
    public Device getDevice() {
        Device cached = this.device;
        if (cached == null) {
            cached = JniUtils.getDevice(this);
            this.device = cached;
        }
        return cached;
    }

    /** Returns the shape, querying the native tensor once and caching the answer. */
    @Override // ai.djl.ndarray.NDArray
    public Shape getShape() {
        Shape cached = this.shape;
        if (cached == null) {
            cached = JniUtils.getShape(this);
            this.shape = cached;
        }
        return cached;
    }

    /** Returns the sparse storage format, querying it once and caching the answer. */
    @Override // ai.djl.ndarray.NDArray
    public SparseFormat getSparseFormat() {
        SparseFormat cached = this.sparseFormat;
        if (cached == null) {
            cached = JniUtils.getSparseFormat(this);
            this.sparseFormat = cached;
        }
        return cached;
    }

    /**
     * Returns this array on {@code device}.
     *
     * @param device the target device
     * @param z when {@code true}, a conversion is performed even if already on {@code device}
     * @return {@code this} when no conversion is needed, otherwise a converted array
     */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray toDevice(Device device, boolean z) {
        if (device.equals(getDevice()) && !z) {
            return this;
        }
        return JniUtils.to(this, getDataType(), device);
    }

    /**
     * Returns this array with element type {@code dataType}.
     *
     * @param dataType the target element type
     * @param z when {@code true}, a conversion is performed even if the type already matches
     * @return {@code this} when no conversion is needed, otherwise a converted array
     */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray toType(DataType dataType, boolean z) {
        if (dataType.equals(getDataType()) && !z) {
            return this;
        }
        return JniUtils.to(this, dataType, getDevice());
    }

    /** Turns gradient tracking on or off and updates the cached flag. */
    @Override // ai.djl.ndarray.NDArray
    public void setRequiresGradient(boolean z) {
        JniUtils.attachGradient(this, z);
        this.hasGradient = z;
    }

    /**
     * Returns the gradient of this array, or a zero array of the same shape if none is stored yet.
     *
     * @throws IllegalStateException if gradient tracking is not enabled
     */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray getGradient() {
        if (!hasGradient()) {
            throw new IllegalStateException("No gradient attached to this NDArray, please call array.setRequiresGradient() on your NDArray or block.setInitializer() on your Block");
        }
        PtNDArray grad = JniUtils.getGradient(this);
        return grad != null ? grad : (PtNDArray) this.manager.zeros(getShape());
    }

    /** Returns whether gradient tracking is enabled, querying native code once if not cached. */
    @Override // ai.djl.ndarray.NDArray
    public boolean hasGradient() {
        Boolean cached = this.hasGradient;
        if (cached == null) {
            cached = JniUtils.requiresGrad(this);
            this.hasGradient = cached;
        }
        return cached;
    }

    /** Returns an array detached from the gradient graph. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray stopGradient() {
        NDArray detached = JniUtils.detachGradient(this);
        return detached;
    }

    /** Returns the contents of this array as a buffer obtained from {@code JniUtils.getByteBuffer}. */
    @Override // ai.djl.ndarray.BytesSupplier
    public ByteBuffer toByteBuffer() {
        ByteBuffer contents = JniUtils.getByteBuffer(this);
        return contents;
    }

    /**
     * Returns the strings backing this array.
     *
     * <p>Only arrays built with the {@code String[]} constructor carry strings; for every other
     * array this returns {@code null}. The charset is not used because the strings are already
     * stored as Java {@code String}s. A copy is returned so callers cannot mutate this array's
     * internal state through the result.
     *
     * @param charset ignored
     * @return a copy of the backing strings, or {@code null} if this is not a string array
     */
    @Override // ai.djl.ndarray.NDArray
    public String[] toStringArray(Charset charset) {
        return this.strs == null ? null : this.strs.clone();
    }

    /**
     * Overwrites the contents of this array with {@code buffer}.
     *
     * <p>A direct {@link ByteBuffer} is passed to native code as-is. Any other buffer is first copied
     * into a newly allocated direct buffer. For arrays not on a GPU, that direct buffer is retained
     * in {@code dataRef}; otherwise {@code dataRef} is cleared.
     *
     * @param buffer the new data, validated against this array's element type and element count
     */
    @Override // ai.djl.ndarray.NDArray
    public void set(Buffer buffer) {
        int elementCount = Math.toIntExact(size());
        DataType type = getDataType();
        BaseNDManager.validateBuffer(buffer, type, elementCount);
        this.dataRef = null;

        ByteBuffer direct;
        if (buffer.isDirect() && buffer instanceof ByteBuffer) {
            direct = (ByteBuffer) buffer;
        } else {
            direct = this.manager.allocateDirect(elementCount * type.getNumOfBytes());
            BaseNDManager.copyBuffer(buffer, direct);
        }

        if (!getDevice().isGpu()) {
            this.dataRef = direct;
        }
        JniUtils.set(this, direct);
    }

    /** Returns the element(s) at the given integer indices, allocated on {@code nDManager}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray get(NDManager nDManager, long... jArr) {
        PtNDManager target = (PtNDManager) nDManager;
        return JniUtils.getItem(this, jArr, target);
    }

    /**
     * Gathers values along {@code axis} using {@code nDArray} as the index.
     *
     * @throws IllegalArgumentException if the index is not a {@link PtNDArray}
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray gather(NDArray nDArray, int i) {
        if (!(nDArray instanceof PtNDArray)) {
            throw new IllegalArgumentException("Only PtNDArray index is supported.");
        }
        return JniUtils.gather(this, (PtNDArray) nDArray, i);
    }

    /**
     * Gathers slices of this array addressed by the rows of {@code nDArray}.
     *
     * <p>Let {@code depth = index.shape[0]}. Row {@code k} of the index holds coordinates along data
     * axis {@code k}. The rows are folded into one row-major linear index, which then indexes this
     * array flattened over axes {@code 0 .. depth-1}.
     *
     * <p>The linear index is built with non-mutating {@code mul}/{@code add}. Rows returned by
     * {@code nDArray.get(...)} may be views of the caller's index (PyTorch basic indexing returns
     * views), so in-place updates could silently corrupt the caller's data. Intermediates are closed
     * as soon as they are no longer needed.
     *
     * @param nDArray the index; must be a {@link PtNDArray} with at least one row along axis 0
     * @return the gathered values
     * @throws IllegalArgumentException if the index is not a {@link PtNDArray}, has no rows, or its
     *     depth exceeds this array's rank
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray gatherNd(NDArray nDArray) {
        if (!(nDArray instanceof PtNDArray)) {
            throw new IllegalArgumentException("Only PtNDArray index is supported.");
        }
        Shape indexShape = nDArray.getShape();
        Shape dataShape = getShape();
        if (indexShape.dimension() == 0 || indexShape.get(0) < 1) {
            throw new IllegalArgumentException(
                    "gatherNd index must have at least one row along axis 0, but got shape "
                            + indexShape);
        }
        int depth = (int) indexShape.get(0);
        if (depth > dataShape.dimension()) {
            throw new IllegalArgumentException("Indexing rank " + indexShape.get(0) + " exceeds the data rank " + dataShape.dimension());
        }

        NDArray linear = nDArray.get("{}, ...", Integer.valueOf(depth - 1));
        try {
            // Accumulate from the innermost indexed axis outwards (row-major strides).
            long stride = 1;
            for (int axis = depth - 2; axis >= 0; axis--) {
                stride *= dataShape.get(axis + 1);
                try (NDArray row = nDArray.get("{}, ...", Integer.valueOf(axis));
                        NDArray scaled = row.mul(Long.valueOf(stride))) {
                    NDArray next = linear.add(scaled);
                    linear.close();
                    linear = next;
                }
            }
            return flatten(0, depth - 1).get(linear);
        } finally {
            linear.close();
        }
    }

    /**
     * Takes elements at the flat positions in {@code nDArray}, allocating on {@code nDManager}.
     *
     * @throws IllegalArgumentException if the index is not a {@link PtNDArray}
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray take(NDManager nDManager, NDArray nDArray) {
        if (!(nDArray instanceof PtNDArray)) {
            throw new IllegalArgumentException("Only PtNDArray is supported.");
        }
        return JniUtils.take(this, (PtNDArray) nDArray, (PtNDManager) nDManager);
    }

    /**
     * Writes {@code nDArray2} at the flat positions in {@code nDArray}.
     *
     * @throws IllegalArgumentException if either argument is not a {@link PtNDArray}
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray put(NDArray nDArray, NDArray nDArray2) {
        if (!(nDArray instanceof PtNDArray) || !(nDArray2 instanceof PtNDArray)) {
            throw new IllegalArgumentException("Only PtNDArray is supported.");
        }
        return JniUtils.put(this, (PtNDArray) nDArray, (PtNDArray) nDArray2);
    }

    /**
     * Scatters {@code nDArray2} into a copy of this array along axis {@code i} at {@code nDArray}.
     *
     * @throws IllegalArgumentException if either argument is not a {@link PtNDArray}
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray scatter(NDArray nDArray, NDArray nDArray2, int i) {
        if (!(nDArray instanceof PtNDArray) || !(nDArray2 instanceof PtNDArray)) {
            throw new IllegalArgumentException("Only PtNDArray is supported.");
        }
        return JniUtils.scatter(this, (PtNDArray) nDArray, (PtNDArray) nDArray2, i);
    }

    /**
     * Moves ownership of this array to {@code nDManager}.
     *
     * <p>The array is first detached from its current manager, then registered with the new one via
     * {@code attachInternal}.
     */
    @Override // ai.djl.ndarray.NDResource
    public void attach(NDManager nDManager) {
        detach();
        this.manager = (PtNDManager) nDManager;
        nDManager.attachInternal(getUid(), this);
    }

    /**
     * Hands this array back to {@code nDManager}.
     *
     * <p>Same as {@link #attach(NDManager)} except that registration goes through
     * {@code attachUncappedInternal}. NOTE(review): the exact meaning of "uncapped" lives in the
     * manager implementation — confirm there.
     */
    @Override // ai.djl.ndarray.NDResource
    public void returnResource(NDManager nDManager) {
        detach();
        this.manager = (PtNDManager) nDManager;
        nDManager.attachUncappedInternal(getUid(), this);
    }

    /**
     * Temporarily attaches this array to {@code nDManager}, remembering the previous owner.
     */
    @Override // ai.djl.ndarray.NDResource
    public void tempAttach(NDManager nDManager) {
        // Capture the current owner BEFORE detach(), which replaces this.manager with the system
        // manager.
        PtNDManager ptNDManager = this.manager;
        detach();
        this.manager = (PtNDManager) nDManager;
        nDManager.tempAttachInternal(ptNDManager, getUid(), this);
    }

    /**
     * Removes this array from its current manager and hands it to the PyTorch system manager.
     */
    @Override // ai.djl.ndarray.NDResource
    public void detach() {
        this.manager.detachInternal(getUid());
        this.manager = PtNDManager.getSystemManager();
    }

    /** Returns a deep copy of this array produced by {@code JniUtils.clone}. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray duplicate() {
        NDArray copy = JniUtils.clone(this);
        return copy;
    }

    /**
     * Selects elements where {@code nDArray} is true.
     *
     * <p>If the mask has the same shape as this array, the flat masked result is returned. If it
     * matches the trailing shape starting at {@code i}, the flat result is reshaped to the leading
     * dimensions {@code [0, i)} plus one trailing dimension of selected elements.
     *
     * @param nDArray the boolean mask
     * @param i the axis where the mask's shape begins within this array's shape
     * @throws UnsupportedOperationException for any other mask shape
     */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray booleanMask(NDArray nDArray, int i) {
        Shape maskShape = nDArray.getShape();
        Shape dataShape = getShape();
        if (maskShape.equals(dataShape)) {
            return JniUtils.booleanMask(this, this.manager.from(nDArray));
        }
        if (!maskShape.equals(dataShape.slice(i))) {
            throw new UnsupportedOperationException("Not supported for shape not broadcastable " + maskShape + " vs " + dataShape);
        }
        try (PtNDArray flat = JniUtils.booleanMask(this, this.manager.from(nDArray))) {
            Shape leading = dataShape.slice(0, i);
            long selected = flat.getShape().size() / leading.size();
            return flat.reshape(leading.addAll(new Shape(selected)));
        }
    }

    /** Not supported by the PyTorch engine. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray sequenceMask(NDArray nDArray, float f) {
        String reason = "Not implemented yet";
        throw new UnsupportedOperationException(reason);
    }

    /** Not supported by the PyTorch engine. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray sequenceMask(NDArray nDArray) {
        String reason = "Not implemented yet";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns whether this array equals a scalar array built from {@code number}.
     *
     * <p>The temporary scalar array is closed before returning. Previously it was left attached to
     * the manager until the manager itself was closed.
     */
    @Override // ai.djl.ndarray.NDArray
    public boolean contentEquals(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return contentEquals(scalar);
        }
    }

    /**
     * Returns whether {@code nDArray} has the same shape, data type and element values as this
     * array. Returns {@code false} for {@code null}.
     */
    @Override // ai.djl.ndarray.NDArray
    public boolean contentEquals(NDArray nDArray) {
        if (nDArray == null || !shapeEquals(nDArray) || getDataType() != nDArray.getDataType()) {
            return false;
        }
        return JniUtils.contentEqual(this, this.manager.from(nDArray));
    }

    /** Element-wise {@code ==} against a scalar; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray eq(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return eq(scalar);
        }
    }

    /** Element-wise {@code ==} against another array (routed through {@code manager.from}). */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray eq(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.eq(this, other);
    }

    /** Element-wise {@code !=} against a scalar; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray neq(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return neq(scalar);
        }
    }

    /** Element-wise {@code !=} against another array (routed through {@code manager.from}). */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray neq(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.neq(this, other);
    }

    /** Element-wise {@code >} against a scalar; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray gt(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return gt(scalar);
        }
    }

    /** Element-wise {@code >} against another array (routed through {@code manager.from}). */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray gt(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.gt(this, other);
    }

    /** Element-wise {@code >=} against a scalar; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray gte(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return gte(scalar);
        }
    }

    /** Element-wise {@code >=} against another array (routed through {@code manager.from}). */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray gte(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.gte(this, other);
    }

    /** Element-wise {@code <} against a scalar; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray lt(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return lt(scalar);
        }
    }

    /** Element-wise {@code <} against another array (routed through {@code manager.from}). */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray lt(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.lt(this, other);
    }

    /** Element-wise {@code <=} against a scalar; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray lte(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return lte(scalar);
        }
    }

    /** Element-wise {@code <=} against another array (routed through {@code manager.from}). */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray lte(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.lte(this, other);
    }

    /** Returns {@code this + number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray add(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return add(scalar);
        }
    }

    /** Returns {@code this + nDArray} as a new array. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray add(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.add(this, other);
    }

    /** Returns {@code this - number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray sub(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return sub(scalar);
        }
    }

    /** Returns {@code this - nDArray} as a new array. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray sub(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.sub(this, other);
    }

    /** Returns {@code this * number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray mul(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return mul(scalar);
        }
    }

    /** Returns {@code this * nDArray} as a new array. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray mul(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.mul(this, other);
    }

    /** Returns {@code this / number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray div(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return div(scalar);
        }
    }

    /** Returns {@code this / nDArray} as a new array. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray div(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.div(this, other);
    }

    /** Returns the remainder of {@code this / number}; the temporary scalar is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray mod(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return mod(scalar);
        }
    }

    /** Returns the element-wise remainder computed by {@code JniUtils.remainder}. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray mod(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.remainder(this, other);
    }

    /** Returns {@code this ^ number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray pow(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return pow(scalar);
        }
    }

    /** Returns {@code this ^ nDArray} as a new array. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray pow(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.pow(this, other);
    }

    /**
     * Computes {@code xlogy} of this array and {@code nDArray}.
     *
     * @throws IllegalArgumentException if either operand is a scalar
     */
    @Override // ai.djl.ndarray.NDArray
    public NDArray xlogy(NDArray nDArray) {
        boolean anyScalar = isScalar() || nDArray.isScalar();
        if (anyScalar) {
            throw new IllegalArgumentException("scalar is not allowed for xlogy()");
        }
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.xlogy(this, other);
    }

    /** In-place {@code this += number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray addi(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return addi(scalar);
        }
    }

    /** In-place {@code this += nDArray}; returns {@code this}. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray addi(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        JniUtils.addi(this, other);
        return this;
    }

    /** In-place {@code this -= number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray subi(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return subi(scalar);
        }
    }

    /** In-place {@code this -= nDArray}; returns {@code this}. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray subi(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        JniUtils.subi(this, other);
        return this;
    }

    /** In-place {@code this *= number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray muli(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return muli(scalar);
        }
    }

    /** In-place {@code this *= nDArray}; returns {@code this}. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray muli(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        JniUtils.muli(this, other);
        return this;
    }

    /** In-place {@code this /= number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray divi(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return divi(scalar);
        }
    }

    /** In-place {@code this /= nDArray}; returns {@code this}. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray divi(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        JniUtils.divi(this, other);
        return this;
    }

    /** In-place remainder by {@code number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray modi(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return modi(scalar);
        }
    }

    /** In-place remainder via {@code JniUtils.remainderi}; returns {@code this}. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray modi(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        JniUtils.remainderi(this, other);
        return this;
    }

    /** In-place {@code this ^= number}; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray powi(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return powi(scalar);
        }
    }

    /** In-place {@code this ^= nDArray}; returns {@code this}. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray powi(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        JniUtils.powi(this, other);
        return this;
    }

    /** Returns the element-wise sign as a new array. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray sign() {
        PtNDArray result = JniUtils.sign(this);
        return result;
    }

    /** Replaces each element with its sign in place; returns {@code this}. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray signi() {
        JniUtils.signi(this);
        return this;
    }

    /** Element-wise maximum against a scalar; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray maximum(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return maximum(scalar);
        }
    }

    /** Element-wise maximum against another array. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray maximum(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.max(this, other);
    }

    /** Element-wise minimum against a scalar; the temporary scalar array is closed afterwards. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray minimum(Number number) {
        try (NDArray scalar = this.manager.create(number)) {
            return minimum(scalar);
        }
    }

    /** Element-wise minimum against another array. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray minimum(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.min(this, other);
    }

    /**
     * Returns whether every element is true.
     *
     * <p>A BOOLEAN copy is made with {@code toType(BOOLEAN, true)} and closed after use.
     */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray all() {
        try (PtNDArray asBool = toType(DataType.BOOLEAN, true)) {
            return JniUtils.all(asBool);
        }
    }

    /**
     * Returns whether any element is true.
     *
     * <p>A BOOLEAN copy is made with {@code toType(BOOLEAN, true)} and closed after use.
     */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray any() {
        try (PtNDArray asBool = toType(DataType.BOOLEAN, true)) {
            return JniUtils.any(asBool);
        }
    }

    /**
     * Returns whether no element is true.
     *
     * <p>A BOOLEAN copy is made with {@code toType(BOOLEAN, true)} and closed after use.
     */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray none() {
        try (PtNDArray asBool = toType(DataType.BOOLEAN, true)) {
            return JniUtils.none(asBool);
        }
    }

    /** Returns the element-wise negation as a new array. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray neg() {
        PtNDArray result = JniUtils.neg(this);
        return result;
    }

    /** Negates every element in place; returns {@code this}. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray negi() {
        JniUtils.negi(this);
        return this;
    }

    /** Returns the element-wise absolute value. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray abs() {
        PtNDArray result = JniUtils.abs(this);
        return result;
    }

    /** Returns the element-wise square. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray square() {
        PtNDArray result = JniUtils.square(this);
        return result;
    }

    /** Returns the element-wise square root. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray sqrt() {
        NDArray result = JniUtils.sqrt(this);
        return result;
    }

    /**
     * Returns the element-wise cube root, computed as {@code pow(this, 1/3)}.
     *
     * <p>The temporary exponent array is closed once the result is produced. Previously it stayed
     * attached to the manager until the manager was closed.
     *
     * <p>NOTE(review): a fractional power of a negative number is likely NaN in PyTorch, so negative
     * elements probably do not yield a real cube root — confirm whether callers rely on that.
     */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray cbrt() {
        try (PtNDArray exponent = (PtNDArray) this.manager.create(1.0d / 3.0d)) {
            return JniUtils.pow(this, exponent);
        }
    }

    /** Returns the element-wise floor. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray floor() {
        PtNDArray result = JniUtils.floor(this);
        return result;
    }

    /** Returns the element-wise ceiling. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray ceil() {
        PtNDArray result = JniUtils.ceil(this);
        return result;
    }

    /** Returns the element-wise rounding. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray round() {
        PtNDArray result = JniUtils.round(this);
        return result;
    }

    /** Returns the element-wise truncation toward zero. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray trunc() {
        PtNDArray result = JniUtils.trunc(this);
        return result;
    }

    /** Returns the element-wise exponential. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray exp() {
        PtNDArray result = JniUtils.exp(this);
        return result;
    }

    /** Not supported by the PyTorch engine. */
    @Override // ai.djl.ndarray.NDArray
    public NDArray gammaln() {
        String reason = "Not implemented yet.";
        throw new UnsupportedOperationException(reason);
    }

    /** Returns the element-wise natural logarithm. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray log() {
        PtNDArray result = JniUtils.log(this);
        return result;
    }

    /** Returns the element-wise base-10 logarithm. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray log10() {
        PtNDArray result = JniUtils.log10(this);
        return result;
    }

    /** Returns the element-wise base-2 logarithm. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray log2() {
        PtNDArray result = JniUtils.log2(this);
        return result;
    }

    /** Returns the element-wise sine. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray sin() {
        PtNDArray result = JniUtils.sin(this);
        return result;
    }

    /** Returns the element-wise cosine. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray cos() {
        PtNDArray result = JniUtils.cos(this);
        return result;
    }

    /** Returns the element-wise tangent. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray tan() {
        PtNDArray result = JniUtils.tan(this);
        return result;
    }

    /** Returns the element-wise inverse sine. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray asin() {
        PtNDArray result = JniUtils.asin(this);
        return result;
    }

    /** Returns the element-wise inverse cosine. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray acos() {
        PtNDArray result = JniUtils.acos(this);
        return result;
    }

    /** Returns the element-wise inverse tangent. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray atan() {
        PtNDArray result = JniUtils.atan(this);
        return result;
    }

    /** Returns the element-wise hyperbolic sine. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray sinh() {
        PtNDArray result = JniUtils.sinh(this);
        return result;
    }

    /** Returns the element-wise hyperbolic cosine. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray cosh() {
        PtNDArray result = JniUtils.cosh(this);
        return result;
    }

    /** Returns the element-wise hyperbolic tangent. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray tanh() {
        PtNDArray result = JniUtils.tanh(this);
        return result;
    }

    /** Not supported by the PyTorch engine. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray asinh() {
        String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /** Not supported by the PyTorch engine. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray acosh() {
        String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /** Not supported by the PyTorch engine. */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray atanh() {
        String reason = "Not implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Converts radians to degrees element-wise: {@code this * 180 / PI}.
     *
     * <p>The intermediate product is closed once the quotient is produced. Previously it stayed
     * attached to the manager until the manager was closed.
     */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray toDegrees() {
        try (PtNDArray scaled = mul(Double.valueOf(180.0d))) {
            return scaled.div(Double.valueOf(Math.PI));
        }
    }

    /**
     * Converts degrees to radians element-wise: {@code this * PI / 180}.
     *
     * <p>The intermediate product is closed once the quotient is produced.
     */
    @Override // ai.djl.ndarray.NDArray
    public PtNDArray toRadians() {
        try (PtNDArray scaled = mul(Double.valueOf(Math.PI))) {
            return scaled.div(Double.valueOf(180.0d));
        }
    }

    @Override // ai.djl.ndarray.NDArray
    public PtNDArray max() {
        return JniUtils.max(this);
    }

    /**
     * Computes the maximum along a single axis.
     *
     * @param iArr the axes to reduce; exactly one axis is currently supported
     * @param z flag forwarded unchanged to {@code JniUtils.max} (presumably keep-dims)
     * @return the reduced array
     * @throws IllegalArgumentException if {@code iArr} is empty (previously this surfaced as an
     *     {@code ArrayIndexOutOfBoundsException} from {@code iArr[0]})
     * @throws UnsupportedOperationException if more than one axis is given
     */
    @Override
    public PtNDArray max(int[] iArr, boolean z) {
        if (iArr.length == 0) {
            throw new IllegalArgumentException("At least one axis must be specified");
        }
        if (iArr.length > 1) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        return JniUtils.max(this, iArr[0], z);
    }

    /**
     * Reduces the whole array to its minimum via {@code JniUtils.min}.
     *
     * @return the array produced by {@code JniUtils.min}
     */
    @Override
    public PtNDArray min() {
        final PtNDArray result = JniUtils.min(this);
        return result;
    }

    /**
     * Computes the minimum along a single axis.
     *
     * @param iArr the axes to reduce; exactly one axis is currently supported
     * @param z flag forwarded unchanged to {@code JniUtils.min} (presumably keep-dims)
     * @return the reduced array
     * @throws IllegalArgumentException if {@code iArr} is empty (previously this surfaced as an
     *     {@code ArrayIndexOutOfBoundsException} from {@code iArr[0]})
     * @throws UnsupportedOperationException if more than one axis is given
     */
    @Override
    public PtNDArray min(int[] iArr, boolean z) {
        if (iArr.length == 0) {
            throw new IllegalArgumentException("At least one axis must be specified");
        }
        if (iArr.length > 1) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        return JniUtils.min(this, iArr[0], z);
    }

    /**
     * Reduces the whole array to its sum via {@code JniUtils.sum}.
     *
     * @return the array produced by {@code JniUtils.sum}
     */
    @Override
    public PtNDArray sum() {
        final PtNDArray result = JniUtils.sum(this);
        return result;
    }

    /**
     * Sums over the given axes. Unlike the other axis reductions here, any number of axes is
     * accepted; they are widened to {@code long} for the native call.
     *
     * @param iArr the axes to reduce
     * @param z flag forwarded unchanged to {@code JniUtils.sum} (presumably keep-dims)
     * @return the reduced array
     */
    @Override
    public PtNDArray sum(int[] iArr, boolean z) {
        long[] axes = Arrays.stream(iArr).asLongStream().toArray();
        return JniUtils.sum(this, axes, z);
    }

    /**
     * Cumulative product along axis {@code i}. A {@code null} data type is passed to the
     * native call.
     *
     * @param i the axis
     * @return the array produced by {@code JniUtils.cumProd}
     */
    @Override
    public NDArray cumProd(int i) {
        final DataType unspecified = null;
        return JniUtils.cumProd(this, i, unspecified);
    }

    /**
     * Cumulative product along axis {@code i} with an explicit data type.
     *
     * @param i the axis
     * @param dataType data type forwarded to {@code JniUtils.cumProd}
     * @return the array produced by {@code JniUtils.cumProd}
     */
    @Override
    public NDArray cumProd(int i, DataType dataType) {
        final NDArray result = JniUtils.cumProd(this, i, dataType);
        return result;
    }

    /**
     * Reduces the whole array to its product via {@code JniUtils.prod}.
     *
     * @return the array produced by {@code JniUtils.prod}
     */
    @Override
    public PtNDArray prod() {
        final PtNDArray result = JniUtils.prod(this);
        return result;
    }

    /**
     * Computes the product along a single axis.
     *
     * @param iArr the axes to reduce; exactly one axis is currently supported
     * @param z flag forwarded unchanged to {@code JniUtils.prod} (presumably keep-dims)
     * @return the reduced array
     * @throws IllegalArgumentException if {@code iArr} is empty (previously this surfaced as an
     *     {@code ArrayIndexOutOfBoundsException} from {@code iArr[0]})
     * @throws UnsupportedOperationException if more than one axis is given
     */
    @Override
    public PtNDArray prod(int[] iArr, boolean z) {
        if (iArr.length == 0) {
            throw new IllegalArgumentException("At least one axis must be specified");
        }
        if (iArr.length > 1) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        return JniUtils.prod(this, iArr[0], z);
    }

    /**
     * Reduces the whole array to its mean via {@code JniUtils.mean}.
     *
     * @return the array produced by {@code JniUtils.mean}
     */
    @Override
    public PtNDArray mean() {
        final PtNDArray result = JniUtils.mean(this);
        return result;
    }

    /**
     * Computes the mean along a single axis.
     *
     * @param iArr the axes to reduce; exactly one axis is currently supported
     * @param z flag forwarded unchanged to {@code JniUtils.mean} (presumably keep-dims)
     * @return the reduced array
     * @throws IllegalArgumentException if {@code iArr} is empty (previously this surfaced as an
     *     {@code ArrayIndexOutOfBoundsException} from {@code iArr[0]})
     * @throws UnsupportedOperationException if more than one axis is given
     */
    @Override
    public PtNDArray mean(int[] iArr, boolean z) {
        if (iArr.length == 0) {
            throw new IllegalArgumentException("At least one axis must be specified");
        }
        if (iArr.length > 1) {
            throw new UnsupportedOperationException("Only 1 axis is support!");
        }
        return JniUtils.mean(this, iArr[0], z);
    }

    /**
     * Normalizes this array via {@code JniUtils.normalize}.
     *
     * @param d first argument forwarded unchanged (presumably the norm exponent - confirm)
     * @param j second argument forwarded unchanged (presumably the axis - confirm)
     * @param d2 third argument forwarded unchanged (presumably an epsilon - confirm)
     * @return the array produced by {@code JniUtils.normalize}
     */
    @Override
    public PtNDArray normalize(double d, long j, double d2) {
        final PtNDArray result = JniUtils.normalize(this, d, j, d2);
        return result;
    }

    /**
     * Rotates this array in the plane spanned by exactly two axes via {@code JniUtils.rot90}.
     *
     * @param i rotation count forwarded to the native call
     * @param iArr the two axes defining the rotation plane
     * @return the rotated array
     * @throws IllegalArgumentException if {@code iArr} does not contain exactly two axes
     */
    @Override
    public PtNDArray rotate90(int i, int[] iArr) {
        boolean twoAxes = iArr.length == 2;
        if (twoAxes) {
            return JniUtils.rot90(this, i, iArr);
        }
        throw new IllegalArgumentException("Axes must be 2");
    }

    /**
     * Trace along diagonals. This engine does not provide it.
     *
     * @param i unused
     * @param i2 unused
     * @param i3 unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public PtNDArray trace(int i, int i2, int i3) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Splits this array along axis {@code i} into chunks of size {@code size(i) / j}.
     *
     * <p>If the axis length is not divisible by {@code j}, the chunk size is rounded down and
     * the native split yields a shorter trailing chunk.
     *
     * @param j number of sections; must be positive
     * @param i the axis to split along
     * @return the resulting chunks
     * @throws IllegalArgumentException if {@code j <= 0}. Previously zero raised
     *     {@code ArithmeticException} and a negative value passed a negative size to native code.
     */
    @Override
    public NDList split(long j, int i) {
        if (j <= 0) {
            throw new IllegalArgumentException("Number of sections must be positive, got " + j);
        }
        return JniUtils.split(this, getShape().get(i) / j, i);
    }

    /**
     * Splits this array along axis {@code i} at the given indices.
     *
     * <p>The indices are converted into chunk lengths
     * {@code [idx0, idx1 - idx0, ..., size(i) - idxLast]}, which is the form passed to
     * {@code JniUtils.split}. Sizes are computed directly into a {@code long[]} instead of a
     * raw, boxed list.
     *
     * @param jArr split indices along the axis; when empty, the array is returned as a single
     *     chunk
     * @param i the axis to split along
     * @return the resulting chunks
     */
    @Override
    public NDList split(long[] jArr, int i) {
        if (jArr.length == 0) {
            return new NDList(this);
        }
        int last = jArr.length - 1;
        long[] sizes = new long[jArr.length + 1];
        sizes[0] = jArr[0];
        for (int k = 1; k <= last; k++) {
            sizes[k] = jArr[k] - jArr[k - 1];
        }
        sizes[last + 1] = size(i) - jArr[last];
        return JniUtils.split(this, sizes, i);
    }

    /**
     * Flattens the full range of dimensions by calling {@code JniUtils.flatten} with start 0
     * and end -1.
     *
     * @return the flattened array
     */
    @Override
    public PtNDArray flatten() {
        final long startDim = 0L;
        final long endDim = -1L;
        return JniUtils.flatten(this, startDim, endDim);
    }

    /**
     * Flattens dimensions {@code i} through {@code i2} via {@code JniUtils.flatten}.
     *
     * @param i first dimension to flatten
     * @param i2 last dimension to flatten
     * @return the flattened array
     */
    @Override
    public NDArray flatten(int i, int i2) {
        final NDArray result = JniUtils.flatten(this, i, i2);
        return result;
    }

    /**
     * Fast Fourier transform via {@code JniUtils.fft}.
     *
     * @param j first argument forwarded unchanged (presumably signal length - confirm)
     * @param j2 second argument forwarded unchanged (presumably the axis - confirm)
     * @return the array produced by {@code JniUtils.fft}
     */
    @Override
    public NDArray fft(long j, long j2) {
        final NDArray result = JniUtils.fft(this, j, j2);
        return result;
    }

    /**
     * Short-time Fourier transform via {@code JniUtils.stft}.
     *
     * <p>The window array is cast to {@link PtNDArray}, so it must already belong to this
     * engine. Note the argument order differs from the native call: {@code z} is moved after
     * the window.
     *
     * @param j forwarded unchanged (presumably FFT size - confirm)
     * @param j2 forwarded unchanged (presumably hop length - confirm)
     * @param z flag forwarded unchanged
     * @param nDArray window array, cast to {@link PtNDArray}
     * @param z2 flag forwarded unchanged
     * @param z3 flag forwarded unchanged
     * @return the array produced by {@code JniUtils.stft}
     */
    @Override
    public NDArray stft(long j, long j2, boolean z, NDArray nDArray, boolean z2, boolean z3) {
        PtNDArray window = (PtNDArray) nDArray;
        return JniUtils.stft(this, j, j2, window, z, z2, z3);
    }

    /**
     * Reshapes this array to {@code shape} via {@code JniUtils.reshape}.
     *
     * @param shape the target shape
     * @return the reshaped array
     */
    @Override
    public PtNDArray reshape(Shape shape) {
        long[] dims = shape.getShape();
        return JniUtils.reshape(this, dims);
    }

    /**
     * Inserts a size-one dimension at axis {@code i} via {@code JniUtils.unsqueeze}.
     *
     * @param i the insertion axis
     * @return the expanded array
     */
    @Override
    public PtNDArray expandDims(int i) {
        final PtNDArray result = JniUtils.unsqueeze(this, i);
        return result;
    }

    /**
     * Squeezes this array via {@code JniUtils.squeeze} with no axis argument.
     *
     * @return the squeezed array
     */
    @Override
    public PtNDArray squeeze() {
        final PtNDArray result = JniUtils.squeeze(this);
        return result;
    }

    /**
     * Squeezes axis {@code i} via {@code JniUtils.squeeze}.
     *
     * @param i the axis to squeeze
     * @return the squeezed array
     */
    @Override
    public PtNDArray squeeze(int i) {
        final PtNDArray result = JniUtils.squeeze(this, i);
        return result;
    }

    /**
     * Removes the given size-one axes by reshaping.
     *
     * <p>Negative axes are counted from the end, as elsewhere in this class. Out-of-range axes
     * raise an {@link IllegalArgumentException}. Previously a negative or too-large axis
     * surfaced as an {@code ArrayIndexOutOfBoundsException}. Duplicate axes are tolerated.
     *
     * @param iArr the axes to remove; for a scalar only {@code []} or {@code [0]} is accepted
     * @return the reshaped array, or a duplicate for an accepted scalar
     * @throws IllegalArgumentException if an axis is out of bounds or refers to a dimension
     *     whose size is not one
     */
    @Override
    public PtNDArray squeeze(int[] iArr) {
        if (isScalar()) {
            if (iArr.length == 0 || (iArr.length == 1 && iArr[0] == 0)) {
                return (PtNDArray) duplicate();
            }
            throw new IllegalArgumentException("axis " + iArr[0] + " is out of bounds for array of dimension 0");
        }
        long[] shape = getShape().getShape();
        int rank = shape.length;
        boolean[] squeezed = new boolean[rank];
        int removed = 0;
        for (int axis : iArr) {
            int dim = axis < 0 ? axis + rank : axis;
            if (dim < 0 || dim >= rank) {
                throw new IllegalArgumentException(
                        "axis " + axis + " is out of bounds for array of dimension " + rank);
            }
            if (shape[dim] != 1) {
                throw new IllegalArgumentException("cannot select an axis to squeeze out which has size not equal to one");
            }
            if (!squeezed[dim]) {
                squeezed[dim] = true;
                removed++;
            }
        }
        long[] newShape = new long[rank - removed];
        int next = 0;
        for (int d = 0; d < rank; d++) {
            if (!squeezed[d]) {
                newShape[next++] = shape[d];
            }
        }
        return (PtNDArray) reshape(newShape);
    }

    /**
     * Finds unique elements via {@code JniUtils.unique}.
     *
     * @param num forwarded unchanged (presumably an optional axis - confirm)
     * @param z flag forwarded unchanged
     * @param z2 flag forwarded unchanged
     * @param z3 flag forwarded unchanged
     * @return the list produced by {@code JniUtils.unique}
     */
    @Override
    public NDList unique(Integer num, boolean z, boolean z2, boolean z3) {
        final NDList result = JniUtils.unique(this, num, z, z2, z3);
        return result;
    }

    /**
     * Logical AND with {@code nDArray}, which is first converted through the manager.
     *
     * @param nDArray the other operand
     * @return the array produced by {@code JniUtils.logicalAnd}
     */
    @Override
    public PtNDArray logicalAnd(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.logicalAnd(this, other);
    }

    /**
     * Logical OR with {@code nDArray}, which is first converted through the manager.
     *
     * @param nDArray the other operand
     * @return the array produced by {@code JniUtils.logicalOr}
     */
    @Override
    public PtNDArray logicalOr(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.logicalOr(this, other);
    }

    /**
     * Logical XOR with {@code nDArray}, which is first converted through the manager.
     *
     * @param nDArray the other operand
     * @return the array produced by {@code JniUtils.logicalXor}
     */
    @Override
    public PtNDArray logicalXor(NDArray nDArray) {
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.logicalXor(this, other);
    }

    /**
     * Logical NOT of this array via {@code JniUtils.logicalNot}.
     *
     * @return the array produced by {@code JniUtils.logicalNot}
     */
    @Override
    public PtNDArray logicalNot() {
        final PtNDArray result = JniUtils.logicalNot(this);
        return result;
    }

    /**
     * Returns sorting indices along axis {@code i}.
     *
     * <p>The indices always come from {@code JniUtils.argSort(this, i, false)}. When {@code z}
     * is false, they are flipped along axis {@code i} and the unflipped intermediate is closed.
     *
     * @param i the axis to sort along
     * @param z when true the native indices are returned as-is; otherwise they are reversed
     * @return the index array
     */
    @Override
    public PtNDArray argSort(int i, boolean z) {
        PtNDArray indices = JniUtils.argSort(this, i, false);
        if (!z) {
            PtNDArray reversed = JniUtils.flip(indices, new long[] {i});
            indices.close();
            indices = reversed;
        }
        return indices;
    }

    /**
     * Sorts along the last axis by delegating to {@link #sort(int)} with {@code -1}.
     *
     * @return the sorted array
     */
    @Override
    public PtNDArray sort() {
        final int lastAxis = -1;
        return sort(lastAxis);
    }

    /**
     * Sorts along axis {@code i} via {@code JniUtils.sort}, passing {@code false} as the
     * native flag.
     *
     * @param i the axis
     * @return the sorted array
     */
    @Override
    public PtNDArray sort(int i) {
        final boolean nativeFlag = false;
        return JniUtils.sort(this, i, nativeFlag);
    }

    /**
     * Softmax along axis {@code i}, keeping this array's own data type.
     *
     * @param i the axis
     * @return the array produced by {@code JniUtils.softmax}
     */
    @Override
    public PtNDArray softmax(int i) {
        DataType type = getDataType();
        return JniUtils.softmax(this, i, type);
    }

    /**
     * Log-softmax along axis {@code i}, keeping this array's own data type.
     *
     * @param i the axis
     * @return the array produced by {@code JniUtils.logSoftmax}
     */
    @Override
    public PtNDArray logSoftmax(int i) {
        DataType type = getDataType();
        return JniUtils.logSoftmax(this, i, type);
    }

    /**
     * Cumulative sum with special cases.
     *
     * <ul>
     *   <li>A scalar is reshaped to one element.
     *   <li>An empty array is reshaped to zero elements.
     *   <li>Anything else is summed along axis 0.
     * </ul>
     *
     * @return the cumulative-sum array
     */
    @Override
    public PtNDArray cumSum() {
        if (isScalar()) {
            return (PtNDArray) reshape(1);
        }
        if (isEmpty()) {
            return (PtNDArray) reshape(0);
        }
        return cumSum(0);
    }

    /**
     * Cumulative sum along axis {@code i} via {@code JniUtils.cumSum}.
     *
     * @param i the axis
     * @return the cumulative-sum array
     */
    @Override
    public PtNDArray cumSum(int i) {
        final PtNDArray result = JniUtils.cumSum(this, i);
        return result;
    }

    /**
     * Replaces this array's native tensor with the one owned by {@code nDArray}, then closes
     * {@code nDArray}.
     *
     * <p>Fixes over the previous version:
     *
     * <ul>
     *   <li>Interning an array into itself is a no-op; previously it nulled the handle and
     *       threw a {@code NullPointerException}.
     *   <li>The old handle is only deleted when present and not the {@code -1} sentinel,
     *       matching {@link #close()}.
     *   <li>The donor's {@code dataRef} is transferred. The stolen handle may reference that
     *       buffer, and closing the donor clears its own {@code dataRef}.
     *   <li>Cached metadata describing the old tensor is cleared. These fields are left
     *       {@code null} by the handle-based constructor, so presumably they are re-derived on
     *       demand (TODO confirm against the getters).
     * </ul>
     *
     * @param nDArray the array whose native tensor is taken over; must be a {@link PtNDArray}
     */
    @Override
    public void intern(NDArray nDArray) {
        PtNDArray ptNDArray = (PtNDArray) nDArray;
        if (ptNDArray == this) {
            return;
        }
        Long oldHandle = (Long) this.handle.getAndSet((Long) ptNDArray.handle.getAndSet(null));
        this.dataRef = ptNDArray.dataRef;
        this.shape = null;
        this.dataType = null;
        this.device = null;
        this.sparseFormat = null;
        this.hasGradient = null;
        if (oldHandle != null && oldHandle.longValue() != -1) {
            JniUtils.deleteNDArray(oldHandle.longValue());
        }
        ptNDArray.close();
    }

    /**
     * Element-wise infinity test via {@code JniUtils.isInf}.
     *
     * @return the array produced by {@code JniUtils.isInf}
     */
    @Override
    public PtNDArray isInfinite() {
        final PtNDArray result = JniUtils.isInf(this);
        return result;
    }

    /**
     * Element-wise NaN test via {@code JniUtils.isNaN}.
     *
     * @return the array produced by {@code JniUtils.isNaN}
     */
    @Override
    public PtNDArray isNaN() {
        final PtNDArray result = JniUtils.isNaN(this);
        return result;
    }

    /**
     * Tiles every axis {@code j} times.
     *
     * <p>An empty array is duplicated unchanged. A scalar is treated as having one axis.
     *
     * @param j repeat count applied to every axis
     * @return the tiled array
     */
    @Override
    public PtNDArray tile(long j) {
        if (isEmpty()) {
            return (PtNDArray) duplicate();
        }
        int axes = isScalar() ? 1 : getShape().dimension();
        long[] repeats = new long[axes];
        for (int k = 0; k < axes; k++) {
            repeats[k] = j;
        }
        return tile(repeats);
    }

    /**
     * Tiles a single axis {@code i} by {@code j}, leaving every other axis unrepeated.
     *
     * <p>Previously this threw {@code UnsupportedOperationException}. It now builds a
     * per-axis repeat vector of ones, with {@code j} at the selected axis, and delegates to
     * {@link #tile(long[])}.
     *
     * @param i the axis to tile; negative values count from the end
     * @param j the repeat count for that axis
     * @return the tiled array
     * @throws IllegalArgumentException if {@code i} is out of bounds for this array's rank
     */
    @Override
    public PtNDArray tile(int i, long j) {
        int rank = getShape().dimension();
        int axis = i < 0 ? i + rank : i;
        if (axis < 0 || axis >= rank) {
            throw new IllegalArgumentException(
                    "axis " + i + " is out of bounds for array of dimension " + rank);
        }
        long[] repeats = new long[rank];
        Arrays.fill(repeats, 1L);
        repeats[axis] = j;
        return tile(repeats);
    }

    /**
     * Tiles with explicit per-axis repeat counts via {@code JniUtils.tile}.
     *
     * @param jArr per-axis repeat counts
     * @return the tiled array
     */
    @Override
    public PtNDArray tile(long[] jArr) {
        final PtNDArray result = JniUtils.tile(this, jArr);
        return result;
    }

    /**
     * Tiles this array until it has the given shape.
     *
     * <p>Previously this threw {@code UnsupportedOperationException}. It now derives per-axis
     * repeat counts with {@code repeatsToMatchShape}, the same helper {@code repeat(Shape)}
     * uses, and delegates to {@link #tile(long[])}.
     *
     * @param shape the desired shape; each dimension must be a multiple of the current one
     * @return the tiled array
     * @throws IllegalArgumentException if {@code shape} has too many dimensions or is not a
     *     multiple of the current shape
     */
    @Override
    public PtNDArray tile(Shape shape) {
        return tile(repeatsToMatchShape(shape));
    }

    /**
     * Repeats elements {@code j} times along every axis.
     *
     * <p>An empty array is duplicated unchanged. A scalar is treated as having one axis.
     *
     * @param j repeat count applied to every axis
     * @return the repeated array
     */
    @Override
    public PtNDArray repeat(long j) {
        if (isEmpty()) {
            return (PtNDArray) duplicate();
        }
        int axes = isScalar() ? 1 : getShape().dimension();
        long[] counts = new long[axes];
        for (int k = 0; k < axes; k++) {
            counts[k] = j;
        }
        return repeat(counts);
    }

    /**
     * Repeats elements {@code j} times along axis {@code i} via {@code JniUtils.repeat}.
     *
     * @param i the axis
     * @param j the repeat count
     * @return the repeated array
     */
    @Override
    public PtNDArray repeat(int i, long j) {
        final PtNDArray result = JniUtils.repeat(this, j, i);
        return result;
    }

    /**
     * Repeats along each axis {@code k} by {@code jArr[k]}, one axis at a time.
     *
     * <p>Each intermediate result is closed once the next one exists; {@code this} is never
     * closed. With an empty {@code jArr}, {@code this} itself is returned.
     *
     * @param jArr per-axis repeat counts
     * @return the repeated array
     */
    @Override
    public PtNDArray repeat(long[] jArr) {
        PtNDArray current = this;
        for (int axis = 0; axis < jArr.length; axis++) {
            PtNDArray next = JniUtils.repeat(current, jArr[axis], axis);
            if (current != this) {
                current.close();
            }
            current = next;
        }
        return current;
    }

    /**
     * Repeats this array until it has the given shape, using per-axis counts from
     * {@code repeatsToMatchShape}.
     *
     * @param shape the desired shape
     * @return the repeated array
     */
    @Override
    public PtNDArray repeat(Shape shape) {
        long[] counts = repeatsToMatchShape(shape);
        return repeat(counts);
    }

    /**
     * Computes per-axis multipliers that turn this array's shape into {@code shape}.
     *
     * <p>A target with fewer dimensions is left-padded with this array's own leading
     * dimensions, so those axes get a multiplier of 1. Divisibility is verified first, so the
     * plain integer quotient is exact. The previous {@code Math.round(Math.ceil(...))}
     * round-trip through {@code double} was a no-op that could lose precision above 2^53.
     *
     * @param shape the desired shape
     * @return one multiplier per dimension of this array
     * @throws IllegalArgumentException if {@code shape} has more dimensions than this array,
     *     or any target dimension is not a multiple of a non-zero current dimension
     */
    private long[] repeatsToMatchShape(Shape shape) {
        Shape current = getShape();
        int rank = current.dimension();
        if (shape.dimension() > rank) {
            throw new IllegalArgumentException("The desired shape has too many dimensions");
        }
        Shape target = shape;
        if (target.dimension() < rank) {
            target = current.slice(0, rank - target.dimension()).addAll(target);
        }
        long[] multipliers = new long[rank];
        for (int d = 0; d < rank; d++) {
            long have = current.get(d);
            long want = target.get(d);
            if (have == 0 || want % have != 0) {
                throw new IllegalArgumentException("The desired shape is not a multiple of the original shape");
            }
            multipliers[d] = want / have;
        }
        return multipliers;
    }

    /**
     * Dot product for operands of equal rank no greater than 2.
     *
     * @param nDArray the other operand, converted through the manager
     * @return the array produced by {@code JniUtils.dot}
     * @throws UnsupportedOperationException if the ranks differ or exceed 2
     */
    @Override
    public PtNDArray dot(NDArray nDArray) {
        int rank = getShape().dimension();
        boolean supported = rank == nDArray.getShape().dimension() && rank <= 2;
        if (!supported) {
            throw new UnsupportedOperationException("Dimension mismatch or high dimensional dot operation is not supported. Please use .matMul instead.");
        }
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.dot(this, other);
    }

    /**
     * Matrix product via {@code JniUtils.matmul}. Scalars are rejected.
     *
     * @param nDArray the other operand, converted through the manager
     * @return the product
     * @throws IllegalArgumentException if either operand is a scalar
     */
    @Override
    public NDArray matMul(NDArray nDArray) {
        boolean involvesScalar = isScalar() || nDArray.isScalar();
        if (involvesScalar) {
            throw new IllegalArgumentException("scalar is not allowed for matMul()");
        }
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.matmul(this, other);
    }

    /**
     * Batched matrix product via {@code JniUtils.bmm}. Scalars are rejected.
     *
     * @param nDArray the other operand, converted through the manager
     * @return the product
     * @throws IllegalArgumentException if either operand is a scalar
     */
    @Override
    public NDArray batchMatMul(NDArray nDArray) {
        boolean involvesScalar = isScalar() || nDArray.isScalar();
        if (involvesScalar) {
            throw new IllegalArgumentException("scalar is not allowed for batchMatMul()");
        }
        PtNDArray other = this.manager.from(nDArray);
        return JniUtils.bmm(this, other);
    }

    /**
     * Clips values to {@code [number, number2]} via {@code JniUtils.clip}.
     *
     * @param number lower bound
     * @param number2 upper bound
     * @return the clipped array
     */
    @Override
    public PtNDArray clip(Number number, Number number2) {
        final PtNDArray result = JniUtils.clip(this, number, number2);
        return result;
    }

    /**
     * Swaps axes {@code i} and {@code i2} via {@code JniUtils.transpose}.
     *
     * @param i first axis
     * @param i2 second axis
     * @return the array with the two axes exchanged
     */
    @Override
    public PtNDArray swapAxes(int i, int i2) {
        final PtNDArray result = JniUtils.transpose(this, i, i2);
        return result;
    }

    /**
     * Reverses element order along the given axes via {@code JniUtils.flip}.
     *
     * @param iArr axes to flip, widened to {@code long} for the native call
     * @return the flipped array
     */
    @Override
    public NDArray flip(int... iArr) {
        long[] axes = Arrays.stream(iArr).asLongStream().toArray();
        return JniUtils.flip(this, axes);
    }

    /**
     * Reverses the order of all axes, i.e. permutes by {@code [rank-1, ..., 1, 0]}.
     *
     * @return the transposed array
     */
    @Override
    public PtNDArray transpose() {
        int rank = getShape().dimension();
        int[] reversed = new int[rank];
        for (int k = 0; k < rank; k++) {
            reversed[k] = rank - 1 - k;
        }
        return transpose(reversed);
    }

    /**
     * Permutes axes according to {@code iArr} via {@code JniUtils.permute}.
     *
     * @param iArr the axis permutation
     * @return the permuted array
     * @throws IllegalArgumentException if this array is a scalar and {@code iArr} is non-empty
     */
    @Override
    public PtNDArray transpose(int... iArr) {
        if (isScalar() && iArr.length > 0) {
            throw new IllegalArgumentException("axes don't match NDArray");
        }
        long[] order = Arrays.stream(iArr).asLongStream().toArray();
        return JniUtils.permute(this, order);
    }

    /**
     * Broadcasts this array to {@code shape} via {@code JniUtils.broadcast}.
     *
     * @param shape the target shape
     * @return the broadcast array
     */
    @Override
    public PtNDArray broadcast(Shape shape) {
        final PtNDArray result = JniUtils.broadcast(this, shape);
        return result;
    }

    /**
     * Index of the maximum over the whole array.
     *
     * <p>A scalar yields {@code manager.create(0L)}.
     *
     * @return the index array
     * @throws IllegalArgumentException if this array is empty
     */
    @Override
    public PtNDArray argMax() {
        if (isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMax of an empty NDArray");
        }
        if (isScalar()) {
            return (PtNDArray) this.manager.create(0L);
        }
        return JniUtils.argMax(this);
    }

    /**
     * Index of the maximum along axis {@code i}.
     *
     * <p>A scalar yields {@code manager.create(0L)}; otherwise {@code JniUtils.argMax} is
     * called with {@code false} as its flag.
     *
     * @param i the axis
     * @return the index array
     */
    @Override
    public PtNDArray argMax(int i) {
        if (isScalar()) {
            return (PtNDArray) this.manager.create(0L);
        }
        return JniUtils.argMax(this, i, false);
    }

    /**
     * Index of the minimum over the whole array.
     *
     * <p>A scalar yields {@code manager.create(0L)}.
     *
     * @return the index array
     * @throws IllegalArgumentException if this array is empty
     */
    @Override
    public PtNDArray argMin() {
        if (isEmpty()) {
            throw new IllegalArgumentException("attempt to get argMin of an empty NDArray");
        }
        if (isScalar()) {
            return (PtNDArray) this.manager.create(0L);
        }
        return JniUtils.argMin(this);
    }

    /**
     * Index of the minimum along axis {@code i}.
     *
     * <p>A scalar yields {@code manager.create(0L)}; otherwise {@code JniUtils.argMin} is
     * called with {@code false} as its flag.
     *
     * @param i the axis
     * @return the index array
     */
    @Override
    public PtNDArray argMin(int i) {
        if (isScalar()) {
            return (PtNDArray) this.manager.create(0L);
        }
        return JniUtils.argMin(this, i, false);
    }

    /**
     * Percentile over the whole array. This engine does not provide it.
     *
     * @param number unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public PtNDArray percentile(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Percentile along axes. This engine does not provide it.
     *
     * @param number unused
     * @param iArr unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public PtNDArray percentile(Number number, int[] iArr) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Median over the whole array. This engine does not provide it.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public PtNDArray median() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Median along axes. This engine does not provide it.
     *
     * @param iArr unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public PtNDArray median(int[] iArr) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns a dense copy of this array.
     *
     * <p>Sparse arrays, and arrays whose native layout code is 2 (presumably a non-strided
     * layout such as MKLDNN; confirm against {@code JniUtils.getLayout}), go through
     * {@code JniUtils.toDense}. Everything else is simply duplicated.
     *
     * @return the dense array
     */
    @Override
    public PtNDArray toDense() {
        boolean needsConversion = isSparse() || JniUtils.getLayout(this) == 2;
        if (needsConversion) {
            return JniUtils.toDense(this);
        }
        return (PtNDArray) duplicate();
    }

    /**
     * Converts this array to the COO sparse format, the only one supported here.
     *
     * <p>If the array is already in that format, a duplicate is returned.
     *
     * @param sparseFormat the requested format; must be {@code COO}
     * @return the sparse array
     * @throws IllegalArgumentException if {@code DENSE} is requested
     * @throws UnsupportedOperationException for any other non-COO format
     */
    @Override
    public PtNDArray toSparse(SparseFormat sparseFormat) {
        if (sparseFormat == SparseFormat.DENSE) {
            throw new IllegalArgumentException("Default type is not allowed");
        }
        if (sparseFormat != SparseFormat.COO) {
            throw new UnsupportedOperationException("Only COO sparse type supported for PyTorch");
        }
        boolean alreadyInFormat = sparseFormat == getSparseFormat();
        if (alreadyInFormat) {
            return (PtNDArray) duplicate();
        }
        return JniUtils.toSparse(this);
    }

    /**
     * Locates non-zero elements via {@code JniUtils.nonZeros}.
     *
     * @return the array produced by {@code JniUtils.nonZeros}
     */
    @Override
    public PtNDArray nonzero() {
        final PtNDArray result = JniUtils.nonZeros(this);
        return result;
    }

    /**
     * Inverse error function via {@code JniUtils.erfinv}.
     *
     * @return the array produced by {@code JniUtils.erfinv}
     */
    @Override
    public PtNDArray erfinv() {
        final PtNDArray result = JniUtils.erfinv(this);
        return result;
    }

    /**
     * Inverse via {@code JniUtils.inverse}.
     *
     * @return the array produced by {@code JniUtils.inverse}
     */
    @Override
    public PtNDArray inverse() {
        final PtNDArray result = JniUtils.inverse(this);
        return result;
    }

    /**
     * Norm with order 2 and an empty axis list, via {@code JniUtils.norm}.
     *
     * @param z flag forwarded unchanged (presumably keep-dims)
     * @return the norm array
     */
    @Override
    public NDArray norm(boolean z) {
        final int order = 2;
        final int[] noAxes = new int[0];
        return JniUtils.norm(this, order, noAxes, z);
    }

    /**
     * Norm of order {@code i} over the given axes, via {@code JniUtils.norm}.
     *
     * @param i norm order
     * @param iArr axes to reduce
     * @param z flag forwarded unchanged (presumably keep-dims)
     * @return the norm array
     */
    @Override
    public NDArray norm(int i, int[] iArr, boolean z) {
        final NDArray result = JniUtils.norm(this, i, iArr, z);
        return result;
    }

    /**
     * One-hot encodes with depth {@code i}, producing {@code FLOAT32} output.
     *
     * @param i encoding depth
     * @return the encoded array
     */
    @Override
    public NDArray oneHot(int i) {
        DataType defaultType = DataType.FLOAT32;
        return JniUtils.oneHot(this, i, defaultType);
    }

    /**
     * One-hot encodes with depth {@code i} and the given output data type.
     *
     * @param i encoding depth
     * @param dataType output data type
     * @return the encoded array
     */
    @Override
    public NDArray oneHot(int i, DataType dataType) {
        final NDArray result = JniUtils.oneHot(this, i, dataType);
        return result;
    }

    /**
     * One-hot encoding with custom on/off values. This engine does not provide it.
     *
     * @param i unused
     * @param f unused
     * @param f2 unused
     * @param dataType unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray oneHot(int i, float f, float f2, DataType dataType) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Batched dot product. This engine does not provide it.
     *
     * @param nDArray unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray batchDot(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Complex view/conversion via {@code JniUtils.complex}.
     *
     * @return the array produced by {@code JniUtils.complex}
     */
    @Override
    public NDArray complex() {
        final NDArray result = JniUtils.complex(this);
        return result;
    }

    /**
     * Real-part view/conversion via {@code JniUtils.real}.
     *
     * @return the array produced by {@code JniUtils.real}
     */
    @Override
    public NDArray real() {
        final NDArray result = JniUtils.real(this);
        return result;
    }

    /**
     * Returns the engine-specific extension object.
     *
     * <p>Arrays built by the string-array constructor never create one.
     *
     * @return the {@link PtNDArrayEx} bound to this array
     * @throws UnsupportedOperationException if no extension exists (string tensors)
     */
    @Override
    public PtNDArrayEx getNDArrayInternal() {
        PtNDArrayEx ex = this.ptNDArrayEx;
        if (ex != null) {
            return ex;
        }
        throw new UnsupportedOperationException("NDArray operation is not supported for String tensor");
    }

    /**
     * Renders this array for debugging.
     *
     * <ul>
     *   <li>A released array returns a fixed message.
     *   <li>A string array prints its backing strings.
     *   <li>An array with native layout 0 prints its debug string directly.
     *   <li>Any other layout is converted with {@link #toDense()} into a temporary that is
     *       closed once rendered.
     * </ul>
     *
     * @return a human-readable description
     */
    @Override
    public String toString() {
        if (isReleased()) {
            return "This array is already closed";
        }
        if (getDataType() == DataType.STRING) {
            return Arrays.toString(this.strs);
        }
        if (JniUtils.getLayout(this) == 0) {
            return toDebugString();
        }
        try (PtNDArray dense = toDense()) {
            return dense.toDebugString();
        }
    }

    /**
     * Content-based equality against any {@link NDArray}.
     *
     * <p>An identity check comes first. It keeps {@code equals} reflexive as required by
     * {@link Object#equals(Object)}, even when a content comparison would not report an array
     * equal to itself. It also skips a needless native comparison.
     *
     * @param obj the object to compare with
     * @return {@code true} if {@code obj} is this array or an {@link NDArray} whose contents
     *     are equal according to {@code contentEquals}
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj instanceof NDArray) {
            return contentEquals((NDArray) obj);
        }
        return false;
    }

    /**
     * Returns a constant hash code.
     *
     * <p>{@link #equals(Object)} compares contents, so equal arrays must hash alike. A constant
     * satisfies that contract without reading native tensor data, at the cost of poor
     * distribution in hash-based collections.
     *
     * @return always {@code 0}
     */
    @Override
    public int hashCode() {
        final int constantHash = 0;
        return constantHash;
    }

    /**
     * Releases this array.
     *
     * <ol>
     *   <li>Runs the {@code onClose()} hook.
     *   <li>Atomically clears the handle.
     *   <li>Deletes the native tensor, unless the handle was absent or is the {@code -1}
     *       sentinel used by the string-array constructor.
     *   <li>Detaches this array from its manager.
     *   <li>Drops the retained data buffer reference.
     * </ol>
     */
    @Override
    public void close() {
        onClose();
        final Long nativeHandle = (Long) this.handle.getAndSet(null);
        final boolean ownsNativeTensor = nativeHandle != null && nativeHandle.longValue() != -1;
        if (ownsNativeTensor) {
            JniUtils.deleteNDArray(nativeHandle.longValue());
        }
        this.manager.detachInternal(getUid());
        this.dataRef = null;
    }
}
