package com.streamscape.recasepunc;

import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDList;
import ai.djl.pytorch.engine.PtEngineProvider;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.streamscape.Trace;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.stream.Collectors;

/* loaded from: input_file:com/streamscape/recasepunc/RecasepuncPredictor.class */
/**
 * Predicts punctuation and casing classes for an integer sequence using a PyTorch model loaded
 * through DJL.
 *
 * <p>Lifecycle: construct with the model and vocabulary locations, call {@link #initialize()} once,
 * call {@link #predict(int[])} as needed, then {@link #close()} to release the model resources.
 */
public class RecasepuncPredictor {
    /**
     * Punctuation string for each punctuation class: entry {@code i} corresponds to the
     * {@code PUNCTUATION_*_INDEX} constant with value {@code i}. Shared array — do not modify.
     */
    public static final String[] punctuationSymbols = {"", ",", ".", "?", "!"};
    public static final int PUNCTUATION_NO_INDEX = 0;
    public static final int PUNCTUATION_COMMA_INDEX = 1;
    public static final int PUNCTUATION_PERIOD_INDEX = 2;
    public static final int PUNCTUATION_QUESTION_INDEX = 3;
    public static final int PUNCTUATION_EXCLAMATION_INDEX = 4;
    public static final int CASE_LOWER_INDEX = 0;
    public static final int CASE_UPPER_INDEX = 1;
    public static final int CASE_CAPITALIZE_INDEX = 2;
    /**
     * Must be distinct from {@link #CASE_CAPITALIZE_INDEX} (it previously duplicated the value 2,
     * making the two classes indistinguishable). NOTE(review): assumes the model's case head orders
     * its classes lower/upper/capitalize/other as 0..3 — confirm against the model's label order.
     */
    public static final int CASE_OTHER_INDEX = 3;

    private final String modelLocation;
    private final String vocabLocation;
    private Model model;
    private Predictor<int[], PunctuationPredictions> predictor;
    private RecasepuncVocab vocabulary;

    /** Class indices chosen by the model (argmax over axis 1 of each output tensor). */
    public static class PunctuationPredictions {
        /** Punctuation class per position (see {@code PUNCTUATION_*_INDEX}). */
        public long[] punctuations;
        /** Case class per position (see {@code CASE_*_INDEX}). */
        public long[] cases;

        public PunctuationPredictions(long[] punctuations, long[] cases) {
            this.punctuations = punctuations;
            this.cases = cases;
        }
    }

    /**
     * Stores the resource locations; nothing is loaded until {@link #initialize()} is called.
     *
     * @param modelLocation model path, used both as the DJL model name and as the load path
     * @param vocabLocation location handed to {@link RecasepuncVocab}
     */
    public RecasepuncPredictor(String modelLocation, String vocabLocation) {
        this.modelLocation = modelLocation;
        this.vocabLocation = vocabLocation;
    }

    /**
     * Registers the PyTorch engine, loads the vocabulary and the model, and creates the predictor.
     *
     * @throws RecasepuncException if loading the vocabulary or model, or creating the predictor,
     *     fails; {@link #close()} is called first to release anything already loaded
     */
    public void initialize() {
        System.setProperty("PYTORCH_VERSION", "1.13.1");
        Engine.registerEngine(new PtEngineProvider());
        overrideDefaultEngine();

        Path path = Paths.get(modelLocation);
        try {
            vocabulary = new RecasepuncVocab(vocabLocation);
            vocabulary.initialize();
            model = Model.newInstance(modelLocation);
            model.load(path);
            predictor = model.newPredictor(new RecasepuncTranslator());
        } catch (Exception e) {
            close();
            throw new RecasepuncException(e);
        }
    }

    /**
     * Best-effort reflection hack: invokes DJL's static {@code Engine.initEngine()} and writes its
     * result into the static {@code Engine.DEFAULT_ENGINE} field after clearing that field's
     * {@code final} modifier (presumably so the engine registered above becomes the default).
     * Any failure is logged through {@link Trace} and otherwise ignored.
     *
     * <p>NOTE(review): clearing {@code final} via {@code Field.modifiers} only works up to JDK 11;
     * from JDK 12 that field is filtered from reflection, so this step throws and is just logged.
     */
    private void overrideDefaultEngine() {
        try {
            Method initEngine = Engine.class.getDeclaredMethod("initEngine");
            initEngine.setAccessible(true);
            String engineName = (String) initEngine.invoke(null);

            Field defaultEngine = Engine.class.getDeclaredField("DEFAULT_ENGINE");
            Field modifiers = Field.class.getDeclaredField("modifiers");
            modifiers.setAccessible(true);
            // Must happen before the first set(): the field accessor is built from these modifiers.
            modifiers.setInt(defaultEngine, defaultEngine.getModifiers() & ~Modifier.FINAL);
            defaultEngine.setAccessible(true);
            defaultEngine.set(null, engineName);
        } catch (Exception e) {
            Trace.logException(this, e, true);
        }
    }

    /**
     * Converts an {@code int[]} into the model input tensor and the model's first two outputs into
     * {@link PunctuationPredictions}.
     */
    private static final class RecasepuncTranslator
            implements Translator<int[], PunctuationPredictions> {

        @Override
        public NDList processInput(TranslatorContext ctx, int[] input) {
            // Allocate from the per-prediction manager so the tensor is freed once the prediction
            // completes; the model's own manager lives until close() and would retain one tensor
            // per predict() call.
            return new NDList(ctx.getNDManager().create(input));
        }

        @Override
        public PunctuationPredictions processOutput(TranslatorContext ctx, NDList outputs) {
            // Output 0 feeds the punctuation classes, output 1 the case classes; argMax(1) selects
            // the highest-scoring class along axis 1 for each position.
            long[] punctuations = outputs.get(0).argMax(1).toLongArray();
            long[] cases = outputs.get(1).argMax(1).toLongArray();
            return new PunctuationPredictions(punctuations, cases);
        }
    }

    /**
     * Returns the vocabulary created by {@link #initialize()}, or {@code null} if it has not run.
     */
    public RecasepuncVocab getVocabulary() {
        return this.vocabulary;
    }

    /**
     * Releases the predictor and the model if present. Safe to call repeatedly; the model is
     * closed even if closing the predictor throws.
     */
    public void close() {
        Predictor<int[], PunctuationPredictions> oldPredictor = predictor;
        Model oldModel = model;
        predictor = null;
        model = null;
        try {
            if (oldPredictor != null) {
                oldPredictor.close();
            }
        } finally {
            if (oldModel != null) {
                oldModel.close();
            }
        }
    }

    /**
     * Runs the model on one input sequence.
     *
     * @param input the sequence to classify (presumably ids from {@link #getVocabulary()} — confirm
     *     with callers)
     * @return the predicted punctuation and case classes
     * @throws IllegalStateException if {@link #initialize()} has not succeeded or {@link #close()}
     *     has been called
     * @throws RecasepuncException if DJL reports a {@link TranslateException}
     */
    public PunctuationPredictions predict(int[] input) {
        Predictor<int[], PunctuationPredictions> current = this.predictor;
        if (current == null) {
            throw new IllegalStateException(
                    "RecasepuncPredictor is not initialized or has been closed");
        }
        try {
            return current.predict(input);
        } catch (TranslateException e) {
            throw new RecasepuncException(e);
        }
    }
}
