package ai.djl.pytorch.engine;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.pytorch.jni.IValue;
import ai.djl.pytorch.jni.IValueUtils;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A symbol block backed by a native PyTorch module handle, operated on through {@link JniUtils}
 * and {@link IValueUtils}.
 *
 * <p>The block attaches itself to an NDManager under a uid derived from the native handle, and
 * {@link #close()} deletes the native module and detaches it again.
 */
public class PtSymbolBlock extends AbstractSymbolBlock implements AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(PtSymbolBlock.class);

    /**
     * Native module handle holder. The reference object itself is never {@code null}; its value is
     * {@code null} until a module is attached (constructor or {@link #loadParameters}) and again
     * after {@link #close()}.
     */
    private final AtomicReference<Long> handle = new AtomicReference<>();

    /** Key under which this block is attached to a manager; the decimal form of the handle. */
    private String uid;

    /** Manager passed at construction; set to {@code null} by {@link #close()}. */
    private PtNDManager manager;

    /** Last training flag pushed to the native module, used to skip redundant mode switches. */
    private boolean isTrain;

    /**
     * Input/output (name, shape) pairs captured on the first forward pass; {@code null} until then.
     * Volatile and only assigned fully-built lists, so {@link #describeInput()} and
     * {@link #describeOutput()} never observe a half-filled list from another thread.
     */
    private volatile PairList<String, Shape> inputDescriptions;

    private volatile PairList<String, Shape> outputDescriptions;

    /** {@code true} until the first forward pass has recorded the descriptions above. */
    private boolean first;

    /** Lazily built cache of the module's parameters, in the order the native side returns them. */
    private Map<String, Parameter> parameters;

    /**
     * Creates a block wrapping an already-loaded native module and attaches it to {@code manager}.
     *
     * @param ptNDManager the manager that will own this block
     * @param rawHandle the native module handle
     */
    public PtSymbolBlock(PtNDManager ptNDManager, long rawHandle) {
        this(ptNDManager);
        this.handle.set(rawHandle);
        this.uid = String.valueOf(rawHandle);
        ptNDManager.attachInternal(this.uid, this);
    }

    /**
     * Creates a block with no native module yet; one is expected to be supplied later through
     * {@link #loadParameters(NDManager, DataInputStream)}.
     *
     * @param ptNDManager the manager associated with this block
     */
    public PtSymbolBlock(PtNDManager ptNDManager) {
        this.manager = ptNDManager;
        this.isTrain = true;
        this.first = true;
    }

    /**
     * Deletes the native module and detaches this block from its manager.
     *
     * <p>Idempotent: calling it again, or on a block that never received a module, does nothing.
     */
    @Override
    public void close() {
        Long released = handle.getAndSet(null);
        if (released == null) {
            return;
        }
        JniUtils.deleteModule(released);
        if (manager != null) {
            manager.detachInternal(uid);
        }
        manager = null;
    }

    /**
     * Runs the module directly on {@link IValue} inputs.
     *
     * @param iValueArr the module inputs
     * @return the module output as returned by {@link IValueUtils#forward}
     */
    public IValue forward(IValue... iValueArr) {
        return IValueUtils.forward(this, iValueArr);
    }

    /**
     * Runs the module on {@code inputs}, switching the native training/inference mode if needed.
     *
     * <p>If the system property {@code ai.djl.pytorch.graph_optimizer} is set, its boolean value is
     * forwarded to {@link JniUtils#setGraphExecutorOptimize} on every call. The first successful
     * call also records input/output descriptions for {@link #describeInput()} and
     * {@link #describeOutput()}.
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        if (isTrain != training) {
            isTrain = training;
            if (training) {
                JniUtils.enableTrainingMode(this);
            } else {
                JniUtils.enableInferenceMode(this);
            }
        }
        if (System.getProperty("ai.djl.pytorch.graph_optimizer") != null) {
            JniUtils.setGraphExecutorOptimize(Boolean.getBoolean("ai.djl.pytorch.graph_optimizer"));
        }
        if (first) {
            // NOTE(review): the lock is class-wide (shared by all instances), as in the original;
            // presumably deliberate to serialize first runs across models — confirm before narrowing.
            synchronized (PtSymbolBlock.class) {
                if (first) {
                    // Capture inputs before forwarding, but publish only after the forward pass
                    // succeeds so a failure leaves both descriptions unset (null).
                    PairList<String, Shape> inDesc = describe(inputs);
                    NDList outputs = IValueUtils.forward(this, inputs, training);
                    PairList<String, Shape> outDesc = describe(outputs);
                    inputDescriptions = inDesc;
                    outputDescriptions = outDesc;
                    first = false;
                    return outputs;
                }
            }
        }
        return IValueUtils.forward(this, inputs, training);
    }

    /** Returns the (name, shape) pair of every array in {@code list}, in order. */
    private static PairList<String, Shape> describe(NDList list) {
        PairList<String, Shape> description = new PairList<>();
        for (NDArray array : list) {
            description.add(array.getName(), array.getShape());
        }
        return description;
    }

    /**
     * Returns the input descriptions recorded on the first forward pass.
     *
     * @return the recorded descriptions, or {@code null} (with a warning logged) if no forward
     *     pass has completed yet
     */
    @Override
    public PairList<String, Shape> describeInput() {
        PairList<String, Shape> result = inputDescriptions;
        if (result == null) {
            logger.warn(
                    "Input shapes are unknown, please run predict or forward once and call describeInput again.");
        }
        return result;
    }

    /**
     * Returns this module's parameters, fetching them from the native side on first call and caching
     * them afterwards. Each parameter's type is guessed from its name via {@link #inferType}.
     *
     * @return the module's parameters
     */
    @Override
    public ParameterList getDirectParameters() {
        if (parameters == null) {
            NDList arrays = JniUtils.moduleGetParams(this, manager);
            Map<String, Parameter> byName = new LinkedHashMap<>(arrays.size());
            for (NDArray array : arrays) {
                String name = array.getName();
                byName.put(
                        name,
                        Parameter.builder()
                                .setName(name)
                                .setType(inferType(name))
                                .optArray(array)
                                .build());
            }
            parameters = byName;
        }
        return new ParameterList(parameters);
    }

    /**
     * Guesses a parameter's type from substrings of its name; the first matching rule wins.
     *
     * @param str the parameter name
     * @return the inferred type, or {@link Parameter.Type#OTHER} if nothing matches
     */
    private static Parameter.Type inferType(String str) {
        if (str.contains("bias")) {
            return Parameter.Type.BIAS;
        }
        if (str.contains("gamma")) {
            return Parameter.Type.GAMMA;
        }
        if (str.contains("beta")) {
            return Parameter.Type.BETA;
        }
        if (str.contains("moving_mean") || str.contains("running_mean")) {
            return Parameter.Type.RUNNING_MEAN;
        }
        if (str.contains("moving_var") || str.contains("running_var")) {
            return Parameter.Type.RUNNING_VAR;
        }
        if (str.contains("weight")) {
            return Parameter.Type.WEIGHT;
        }
        return Parameter.Type.OTHER;
    }

    /**
     * Returns the output descriptions recorded on the first forward pass.
     *
     * @return the recorded descriptions, or {@code null} (with a warning logged) if no forward
     *     pass has completed yet
     */
    @Override
    public PairList<String, Shape> describeOutput() {
        PairList<String, Shape> result = outputDescriptions;
        if (result == null) {
            logger.warn(
                    "Output shapes are unknown, please run predict or forward once and call describeOutput again.");
        }
        return result;
    }

    /**
     * Computes output shapes by running a real forward pass on all-ones inputs of the given shapes,
     * created with {@code NDManager.ones(Shape)}.
     *
     * <p>Side effect: the pass runs with {@code training == false}, so the native module is left in
     * inference mode.
     *
     * @param shapeArr the input shapes
     * @return the shapes of the module outputs
     */
    @Override
    public Shape[] getOutputShapes(Shape[] shapeArr) {
        try (NDManager scratch = NDManager.newBaseManager()) {
            NDList inputs = new NDList();
            for (Shape shape : shapeArr) {
                inputs.add(scratch.ones(shape));
            }
            return runForShapes(scratch, inputs);
        }
    }

    /**
     * Computes output shapes by running a real forward pass on all-ones inputs of the given shapes
     * and data types.
     *
     * <p>Side effect: the pass runs with {@code training == false}, so the native module is left in
     * inference mode.
     *
     * @param shapeArr the input shapes
     * @param dataTypeArr the input data types, or {@code null} to use {@link DataType#FLOAT32} for
     *     every input; extra entries beyond {@code shapeArr.length} are ignored
     * @return the shapes of the module outputs
     * @throws IllegalArgumentException if {@code dataTypeArr} has fewer entries than {@code shapeArr}
     */
    @Override
    public Shape[] getOutputShapes(Shape[] shapeArr, DataType[] dataTypeArr) {
        if (dataTypeArr != null && dataTypeArr.length < shapeArr.length) {
            throw new IllegalArgumentException(
                    "Expected at least "
                            + shapeArr.length
                            + " data types but got "
                            + dataTypeArr.length);
        }
        try (NDManager scratch = NDManager.newBaseManager()) {
            NDList inputs = new NDList();
            for (int i = 0; i < shapeArr.length; i++) {
                DataType type = dataTypeArr == null ? DataType.FLOAT32 : dataTypeArr[i];
                inputs.add(scratch.ones(shapeArr[i], type));
            }
            return runForShapes(scratch, inputs);
        }
    }

    /** Runs a non-training forward pass on {@code inputs} and collects the output shapes. */
    private Shape[] runForShapes(NDManager scratch, NDList inputs) {
        NDList outputs = forwardInternal(new ParameterStore(scratch, false), inputs, false, null);
        return outputs.stream().map(NDArray::getShape).toArray(Shape[]::new);
    }

    /**
     * Writes the block version byte followed by the serialized native module.
     *
     * @param dataOutputStream the destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream dataOutputStream) throws IOException {
        dataOutputStream.writeByte(version);
        JniUtils.writeModule(this, dataOutputStream, true);
    }

    /**
     * Reads a version byte and a serialized native module, replacing any module this block held.
     *
     * <p>The new module is loaded onto {@code nDManager}'s device, and this block is attached to
     * {@code nDManager} under the new handle's uid. Any previously held module is deleted only after
     * the new one loads successfully, so a failed load leaves the block unchanged.
     *
     * @param nDManager the manager whose device the module is loaded onto and to which the block is
     *     attached
     * @param dataInputStream the source stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the stored version does not match this block's version
     */
    @Override
    public void loadParameters(NDManager nDManager, DataInputStream dataInputStream)
            throws IOException, MalformedModelException {
        byte loadVersion = dataInputStream.readByte();
        if (loadVersion != version) {
            throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
        }
        long rawHandle =
                JniUtils.loadModuleHandle(dataInputStream, nDManager.getDevice(), true, true);
        Long previous = handle.getAndSet(rawHandle);
        if (previous != null) {
            // Free the replaced module so reloading does not leak native memory.
            // NOTE(review): the old uid is detached from `manager`; if it was attached to a different
            // manager by an earlier loadParameters call, that attachment is not removed — confirm.
            JniUtils.deleteModule(previous);
            if (manager != null) {
                manager.detachInternal(uid);
            }
        }
        uid = String.valueOf(rawHandle);
        // Any cached parameters belonged to the previous module.
        parameters = null;
        nDManager.attachInternal(uid, this);
    }

    /**
     * Returns the native module handle.
     *
     * @return the handle
     * @throws IllegalStateException if no module is attached (never loaded, or already closed)
     */
    public Long getHandle() {
        Long current = handle.get();
        if (current == null) {
            throw new IllegalStateException("PyTorch model handle has been released!");
        }
        return current;
    }
}
