package ai.djl.ndarray;

import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import com.sun.jna.Platform;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Serializes {@link NDArray}s to and from DJL's binary "NDAR" format and the numpy {@code .npy}
 * format.
 *
 * <p>The NDAR layout written by {@link #encode(NDArray, OutputStream)} is, in order:
 *
 * <ol>
 *   <li>the magic string {@code "NDAR"}, written with {@link DataOutputStream#writeUTF(String)}
 *   <li>the format version (currently 3) as a 4-byte int
 *   <li>a flag byte ({@code 0}: no name, {@code 1}: a {@code writeUTF} name follows)
 *   <li>the sparse format name and the data type name, each written with {@code writeUTF}
 *   <li>the encoded shape, as produced by {@code Shape#getEncoded()}
 *   <li>a byte-order marker: {@code '>'} for big-endian data, {@code '<'} for little-endian
 *   <li>the data length in bytes as a 4-byte int, followed by the raw data bytes
 * </ol>
 *
 * <p>Decoding accepts versions 1 to 3: version 1 has no name field, and versions before 3 have no
 * byte-order marker (the data is then assumed to be in native byte order).
 */
public final class NDSerializer {

    /** NDAR format version written by {@link #encode(NDArray, OutputStream)}. */
    private static final int VERSION = 3;

    /** Chunk size (1 MiB) used when copying array data between buffers and streams. */
    private static final int BUFFER_SIZE = 1048576;

    private static final String MAGIC_NUMBER = "NDAR";

    /** Alignment, in bytes, of the data section of {@code .npy} files written here. */
    private static final int ARRAY_ALIGN = 64;

    /** The {@code .npy} magic prefix: byte 0x93 followed by ASCII {@code NUMPY}. */
    private static final byte[] NUMPY_MAGIC = {-109, 78, 85, 77, 80, 89};

    /**
     * Extracts the dtype descriptor (group 1) and the shape tuple body (group 2) from a numpy
     * header; only C-order ({@code 'fortran_order': False}) headers match.
     */
    private static final Pattern PATTERN =
            Pattern.compile(
                    "\\{'descr': '(.+)', 'fortran_order': False, 'shape': \\((.*)\\),");

    /** NDAR byte-order marker for big-endian data. */
    private static final int BIG_ENDIAN_MARK = '>';

    /** NDAR byte-order marker for little-endian data. */
    private static final int LITTLE_ENDIAN_MARK = '<';

    private NDSerializer() {}

    /**
     * Encodes an {@link NDArray} into a byte array in the NDAR format.
     *
     * @param nDArray the array to encode
     * @return the encoded bytes
     * @throws ArithmeticException if the array data is too large to fit in a Java byte array
     */
    public static byte[] encode(NDArray nDArray) {
        // Presize for the raw data plus a small header allowance. Long arithmetic makes an
        // oversized array fail loudly instead of overflowing into a bogus (possibly negative) size.
        int capacity =
                Math.toIntExact(
                        (long) nDArray.size() * nDArray.getDataType().getNumOfBytes() + 100L);
        try (ByteArrayOutputStream bos = new ByteArrayOutputStream(capacity)) {
            encode(nDArray, bos);
            return bos.toByteArray();
        } catch (IOException e) {
            // ByteArrayOutputStream performs no real I/O.
            throw new AssertionError("This should never happen", e);
        }
    }

    /**
     * Writes an {@link NDArray} to a stream in the NDAR format (see the class documentation for
     * the layout). The stream is flushed but not closed.
     *
     * @param nDArray the array to encode
     * @param outputStream the destination stream
     * @throws IOException if writing to the stream fails
     */
    public static void encode(NDArray nDArray, OutputStream outputStream) throws IOException {
        DataOutputStream dos =
                outputStream instanceof DataOutputStream
                        ? (DataOutputStream) outputStream
                        : new DataOutputStream(outputStream);
        dos.writeUTF(MAGIC_NUMBER);
        dos.writeInt(VERSION);
        String name = nDArray.getName();
        if (name == null) {
            dos.write(0);
        } else {
            dos.write(1);
            dos.writeUTF(name);
        }
        dos.writeUTF(nDArray.getSparseFormat().name());
        dos.writeUTF(nDArray.getDataType().name());
        dos.write(nDArray.getShape().getEncoded());

        ByteBuffer data = nDArray.toByteBuffer();
        dos.write(data.order() == ByteOrder.BIG_ENDIAN ? BIG_ENDIAN_MARK : LITTLE_ENDIAN_MARK);
        int length = data.remaining();
        dos.writeInt(length);
        if (length > 0) {
            if (data.hasArray()) {
                // Heap buffer: write straight from the backing array, honoring the buffer's
                // offset into that array, instead of copying through a scratch chunk.
                dos.write(data.array(), data.arrayOffset() + data.position(), length);
            } else {
                byte[] chunk = new byte[Math.min(length, BUFFER_SIZE)];
                while (data.hasRemaining()) {
                    int n = Math.min(data.remaining(), chunk.length);
                    data.get(chunk, 0, n);
                    dos.write(chunk, 0, n);
                }
            }
        }
        dos.flush();
    }

    /**
     * Writes an {@link NDArray} to a stream in the numpy {@code .npy} format, version 1.0, in C
     * (row-major) order. The stream is neither flushed nor closed.
     *
     * @param nDArray the array to encode
     * @param outputStream the destination stream
     * @throws IOException if writing to the stream fails
     */
    public static void encodeAsNumpy(NDArray nDArray, OutputStream outputStream)
            throws IOException {
        StringBuilder header = new StringBuilder(80);
        header.append("{'descr': '")
                .append(nDArray.getDataType().asNumpy())
                .append("', 'fortran_order': False, 'shape': ");
        long[] shape = nDArray.getShape().getShape();
        if (shape.length == 1) {
            // A Python 1-tuple needs a trailing comma.
            header.append('(').append(shape[0]).append(",)");
        } else {
            // NOTE(review): relies on Shape#toString producing a Python tuple like "(2, 3)".
            header.append(nDArray.getShape());
        }
        header.append(", }");

        // Pad with spaces so that magic (6) + version (2) + length field (2) + header + '\n' is a
        // multiple of ARRAY_ALIGN. Like numpy's writer, a full ARRAY_ALIGN of padding is added
        // when the header is already aligned.
        int unpadded = header.length() + 1;
        int padding = ARRAY_ALIGN - ((NUMPY_MAGIC.length + 4 + unpadded) % ARRAY_ALIGN);
        for (int i = 0; i < padding; i++) {
            header.append(' ');
        }
        header.append('\n');

        // The v1.0 header length field is a little-endian unsigned 16-bit value that counts the
        // dictionary text, the padding and the trailing newline.
        ByteBuffer headerLength = ByteBuffer.allocate(2);
        headerLength.order(ByteOrder.LITTLE_ENDIAN);
        headerLength.putShort((short) header.length());

        outputStream.write(NUMPY_MAGIC);
        outputStream.write(1); // major version
        outputStream.write(0); // minor version
        outputStream.write(headerLength.array());
        outputStream.write(header.toString().getBytes(StandardCharsets.US_ASCII));
        outputStream.write(nDArray.toByteArray());
    }

    /**
     * Decodes an {@link NDArray} in the NDAR format from a buffer. On return the buffer's
     * position is just past the array data.
     *
     * <p>The array is created from a slice of {@code byteBuffer}. Whether {@code
     * NDManager#create} copies that slice is not visible here.
     *
     * @param nDManager the manager used to create the array
     * @param byteBuffer the buffer positioned at the start of the encoded array
     * @return the decoded array
     * @throws IllegalArgumentException if the data is malformed or has an unsupported version
     */
    public static NDArray decode(NDManager nDManager, ByteBuffer byteBuffer) {
        if (!MAGIC_NUMBER.equals(readUTF(byteBuffer))) {
            throw new IllegalArgumentException("Malformed NDArray data");
        }
        int version = byteBuffer.getInt();
        checkVersion(version);
        String name = null;
        // Version 1 has no name field.
        if (version > 1 && byteBuffer.get() == 1) {
            name = readUTF(byteBuffer);
        }
        readUTF(byteBuffer); // sparse format name: consumed but not used
        DataType dataType = DataType.valueOf(readUTF(byteBuffer));
        Shape shape = Shape.decode(byteBuffer);
        ByteOrder order = version > 2 ? toByteOrder(byteBuffer.get()) : ByteOrder.nativeOrder();
        int length = byteBuffer.getInt();

        ByteBuffer data = byteBuffer.slice();
        data.limit(length);
        data.order(order);
        NDArray array = nDManager.create(data, shape, dataType);
        array.setName(name);
        byteBuffer.position(byteBuffer.position() + length);
        return array;
    }

    /**
     * Decodes an {@link NDArray} in the NDAR format from a stream. The data is copied into a
     * buffer obtained from {@link NDManager#allocateDirect(int)}.
     *
     * @param nDManager the manager used to allocate the buffer and create the array
     * @param inputStream the stream positioned at the start of the encoded array
     * @return the decoded array
     * @throws IOException if reading from the stream fails
     * @throws IllegalArgumentException if the data is malformed or has an unsupported version
     */
    public static NDArray decode(NDManager nDManager, InputStream inputStream)
            throws IOException {
        DataInputStream dis =
                inputStream instanceof DataInputStream
                        ? (DataInputStream) inputStream
                        : new DataInputStream(inputStream);
        if (!MAGIC_NUMBER.equals(dis.readUTF())) {
            throw new IllegalArgumentException("Malformed NDArray data");
        }
        int version = dis.readInt();
        checkVersion(version);
        String name = null;
        // Version 1 has no name field.
        if (version > 1 && dis.readByte() == 1) {
            name = dis.readUTF();
        }
        dis.readUTF(); // sparse format name: consumed but not used
        DataType dataType = DataType.valueOf(dis.readUTF());
        Shape shape = Shape.decode(dis);
        ByteOrder order = version > 2 ? toByteOrder(dis.readByte()) : ByteOrder.nativeOrder();
        int length = dis.readInt();

        ByteBuffer data = nDManager.allocateDirect(length);
        data.order(order);
        readData(dis, data, length);
        NDArray array = nDManager.create(data, shape, dataType);
        array.setName(name);
        return array;
    }

    /**
     * Decodes an {@link NDArray} from a stream in the numpy {@code .npy} format (versions 1.0,
     * 2.0 and 3.0). Only C-order arrays ({@code 'fortran_order': False}) are accepted.
     *
     * @param nDManager the manager used to allocate the buffer and create the array
     * @param inputStream the stream positioned at the start of the {@code .npy} data
     * @return the decoded array
     * @throws IOException if reading from the stream fails
     * @throws IllegalArgumentException if the magic, version or header is invalid
     */
    public static NDArray decodeNumpy(NDManager nDManager, InputStream inputStream)
            throws IOException {
        DataInputStream dis =
                inputStream instanceof DataInputStream
                        ? (DataInputStream) inputStream
                        : new DataInputStream(inputStream);
        byte[] magic = new byte[NUMPY_MAGIC.length];
        dis.readFully(magic);
        if (!Arrays.equals(magic, NUMPY_MAGIC)) {
            throw new IllegalArgumentException("Malformed numpy data");
        }
        byte major = dis.readByte();
        byte minor = dis.readByte();
        if (major < 1 || major > 3 || minor != 0) {
            throw new IllegalArgumentException(
                    "Unknown numpy version: " + ((int) major) + '.' + ((int) minor));
        }

        // v1.0 stores the header length as a little-endian *unsigned* 16-bit value; v2.0/v3.0
        // use an unsigned 32-bit value. Mask to avoid negative lengths from sign extension.
        byte[] lengthField = new byte[major == 1 ? 2 : 4];
        dis.readFully(lengthField);
        ByteBuffer lengthBuffer = ByteBuffer.wrap(lengthField);
        lengthBuffer.order(ByteOrder.LITTLE_ENDIAN);
        long headerLength =
                major == 1 ? lengthBuffer.getShort() & 0xFFFF : lengthBuffer.getInt() & 0xFFFFFFFFL;
        if (headerLength > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Invalid numpy header length: " + headerLength);
        }
        byte[] headerBytes = new byte[(int) headerLength];
        dis.readFully(headerBytes);
        // Per the .npy spec, v1.0/v2.0 headers are latin1 and v3.0 headers are UTF-8.
        String header =
                new String(
                                headerBytes,
                                major >= 3 ? StandardCharsets.UTF_8 : StandardCharsets.ISO_8859_1)
                        .trim();

        Matcher matcher = PATTERN.matcher(header);
        if (!matcher.find()) {
            throw new IllegalArgumentException("Invalid numpy header: " + header);
        }
        String descr = matcher.group(1);
        DataType dataType = DataType.fromNumpy(descr);
        String dims = matcher.group(2);
        // dims is the tuple body, e.g. "2, 3", "5," (1-D) or "" (scalar).
        Shape shape =
                new Shape(
                        dims.isEmpty()
                                ? new long[0]
                                : Arrays.stream(dims.split(", ?"))
                                        .mapToLong(Long::parseLong)
                                        .toArray());
        int length = Math.toIntExact((long) shape.size() * dataType.getNumOfBytes());

        ByteBuffer data = nDManager.allocateDirect(length);
        switch (descr.charAt(0)) {
            case '>':
                data.order(ByteOrder.BIG_ENDIAN);
                break;
            case '<':
                data.order(ByteOrder.LITTLE_ENDIAN);
                break;
            case '=':
                // numpy's "native" byte order marker.
                data.order(ByteOrder.nativeOrder());
                break;
            default:
                // '|' (byte order not applicable) or no prefix: keep the buffer's order.
                break;
        }
        readData(dis, data, length);
        return nDManager.create(data, shape, dataType);
    }

    /** Rejects NDAR versions outside the supported range {@code 1..VERSION}. */
    private static void checkVersion(int version) {
        if (version < 1 || version > VERSION) {
            throw new IllegalArgumentException("Unexpected NDArray encode version " + version);
        }
    }

    /** Maps an NDAR byte-order marker to a {@link ByteOrder}; anything but '>' is little-endian. */
    private static ByteOrder toByteOrder(byte marker) {
        return marker == BIG_ENDIAN_MARK ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN;
    }

    /**
     * Reads exactly {@code length} bytes from the stream into {@code byteBuffer}, copying in
     * chunks of at most {@link #BUFFER_SIZE}. Rewinds the buffer afterwards so it can be read
     * from position 0. Does nothing when {@code length <= 0}.
     */
    private static void readData(DataInputStream dis, ByteBuffer byteBuffer, int length)
            throws IOException {
        if (length <= 0) {
            return;
        }
        // Size the scratch buffer to the data so small arrays do not allocate a full chunk.
        byte[] chunk = new byte[Math.min(length, BUFFER_SIZE)];
        int remaining = length;
        while (remaining > 0) {
            int n = Math.min(remaining, chunk.length);
            dis.readFully(chunk, 0, n);
            byteBuffer.put(chunk, 0, n);
            remaining -= n;
        }
        byteBuffer.rewind();
    }

    /**
     * Reads a string in the modified UTF-8 format produced by {@link
     * DataOutputStream#writeUTF(String)}: an unsigned 16-bit big-endian byte length followed by
     * 1-, 2- or 3-byte encoded characters. This mirrors {@code DataInputStream.readUTF} but reads
     * from a {@link ByteBuffer}.
     *
     * @throws IllegalArgumentException if the encoded bytes are malformed
     */
    private static String readUTF(ByteBuffer byteBuffer) {
        // Mask both bytes before combining; without the mask, sign extension corrupts lengths
        // of 256 bytes or more.
        int high = byteBuffer.get() & 0xFF;
        int low = byteBuffer.get() & 0xFF;
        int utfLength = (high << 8) | low;
        byte[] bytes = new byte[utfLength];
        byteBuffer.get(bytes);

        char[] chars = new char[utfLength];
        int charCount = 0;
        int pos = 0;
        while (pos < utfLength) {
            int c = bytes[pos] & 0xFF;
            switch (c >> 4) {
                case 0:
                case 1:
                case 2:
                case 3:
                case 4:
                case 5:
                case 6:
                case 7:
                    // 0xxxxxxx: one-byte character.
                    pos++;
                    chars[charCount++] = (char) c;
                    break;
                case 12:
                case 13:
                    {
                        // 110xxxxx 10xxxxxx: two-byte character.
                        pos += 2;
                        if (pos > utfLength) {
                            throw new IllegalArgumentException("malformed UTF-8 input");
                        }
                        int c2 = bytes[pos - 1];
                        if ((c2 & 0xC0) != 0x80) {
                            throw new IllegalArgumentException(
                                    "malformed input around byte " + pos);
                        }
                        chars[charCount++] = (char) (((c & 0x1F) << 6) | (c2 & 0x3F));
                        break;
                    }
                case 14:
                    {
                        // 1110xxxx 10xxxxxx 10xxxxxx: three-byte character.
                        pos += 3;
                        if (pos > utfLength) {
                            throw new IllegalArgumentException("malformed UTF-8 input");
                        }
                        int c2 = bytes[pos - 2];
                        int c3 = bytes[pos - 1];
                        if ((c2 & 0xC0) != 0x80 || (c3 & 0xC0) != 0x80) {
                            throw new IllegalArgumentException("malformed UTF-8 input");
                        }
                        chars[charCount++] =
                                (char) (((c & 0x0F) << 12) | ((c2 & 0x3F) << 6) | (c3 & 0x3F));
                        break;
                    }
                default:
                    // 10xxxxxx or 1111xxxx is not a valid leading byte.
                    throw new IllegalArgumentException("malformed input around byte " + pos);
            }
        }
        return new String(chars, 0, charCount);
    }
}
