package com.streamscape.recasepunc;

import com.streamscape.recasepunc.RecasepuncPredictor;
import com.streamscape.recasepunc.RecasepuncTokenizer;
import com.streamscape.text.service.sttext.PunctuationParams;
import com.streamscape.text.service.sttext.STText;
import com.streamscape.text.service.sttext.STTextPunctuator;
import com.streamscape.text.service.sttext.STTextTimePunctuator;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/* loaded from: input_file:com/streamscape/recasepunc/RecasepuncPunctuator.class */
/**
 * {@link STTextPunctuator} that restores casing and punctuation with a {@link RecasepuncPredictor}.
 *
 * <p>Streaming input passed to {@link #punctuate(List, boolean)} is buffered. It is sent to the
 * predictor once {@link #BATCH_SIZE} tokens are buffered, once the buffered words span at least the
 * configured collection time, or when a flush is requested. Instances keep mutable streaming state
 * (buffer, last emitted word, last punctuation label) and perform no synchronization.
 */
public class RecasepuncPunctuator implements STTextPunctuator {
    /** Maximum number of token ids passed to a single predictor call. */
    private static final int BATCH_SIZE = 256;

    /**
     * Literal token ids that are never written to the output: the conventional BERT {@code [CLS]}
     * (101) and {@code [SEP]} (102) ids. NOTE(review): presumably these match the ids defined in
     * {@code RecasepuncVocab}; confirm.
     */
    private static final int BERT_CLS_TOKEN_ID = 101;
    private static final int BERT_SEP_TOKEN_ID = 102;

    private final RecasepuncTokenizer tokenizer;
    private final PunctuationParams punctuationParams;
    private final RecasepuncPredictor predictor;

    /** Most recently emitted word token; {@code null} until the first word has been written. */
    private RecasepuncTokenizer.Token lastToken = null;

    /**
     * Punctuation label of the most recently emitted word. It starts at 2 so that the {@code > 1}
     * check in {@link #appendWord} capitalizes the first word when its case label is 0 or 2.
     */
    private int lastPunctuationLabelIndex = 2;

    /** Streaming tokens that have been received but not yet punctuated. */
    private final List<RecasepuncTokenizer.Token> collectedTokens = new ArrayList<>();

    /**
     * Creates a punctuator.
     *
     * @param recasepuncPredictor model used for case/punctuation prediction; its vocabulary also
     *     backs the tokenizer
     * @param punctuationParams collection-time and new-line-delay settings used when punctuating
     */
    public RecasepuncPunctuator(RecasepuncPredictor recasepuncPredictor, PunctuationParams punctuationParams) {
        this.predictor = recasepuncPredictor;
        this.tokenizer = new RecasepuncTokenizer(recasepuncPredictor.getVocabulary());
        this.punctuationParams = punctuationParams;
    }

    @Override // com.streamscape.text.service.sttext.STTextPunctuator
    public PunctuationParams getPunctuationParams() {
        return this.punctuationParams;
    }

    /**
     * Buffers newly recognized words and punctuates the buffer when it is due.
     *
     * <p>Punctuation runs when at least {@link #BATCH_SIZE} tokens are buffered, when the buffered
     * words span at least the configured collection time, or when {@code flush} is true. Full
     * batches are punctuated first. The remainder is punctuated only if the time span is still
     * reached or {@code flush} is set; otherwise it stays buffered for the next call.
     *
     * @param words newly recognized words, in order
     * @param flush when true, punctuates everything buffered and may close the text with "."
     * @return the text produced by this call (possibly empty)
     */
    @Override // com.streamscape.text.service.sttext.STTextPunctuator
    public STTextPunctuator.PunctuatedText punctuate(List<STText.TranscribeExplainedWord> words, boolean flush) {
        // Build the new tokens first so a failure while mapping leaves the buffer untouched.
        List<RecasepuncTokenizer.Token> incoming = new ArrayList<>(words.size());
        for (STText.TranscribeExplainedWord word : words) {
            String text = word.getWord();
            incoming.add(new RecasepuncTokenizer.Token(text, this.tokenizer.getTokenId(text), word.getStart(), word.getEnd()));
        }
        this.collectedTokens.addAll(incoming);

        STTextPunctuator.PunctuatedTextBuilder builder = new STTextPunctuator.PunctuatedTextBuilder();
        if (this.collectedTokens.size() >= BATCH_SIZE || collectionTimeElapsed() || flush) {
            if (this.lastToken == null && !this.collectedTokens.isEmpty()) {
                // Nothing has been emitted yet: start the model input with a [CLS] marker,
                // matching punctuate(String). It must go into the buffer that is actually punctuated.
                this.collectedTokens.add(0, newClsToken());
            }
            while (this.collectedTokens.size() >= BATCH_SIZE) {
                List<RecasepuncTokenizer.Token> batch = this.collectedTokens.subList(0, BATCH_SIZE);
                // Intermediate batches never close the text; only the final call below may add ".".
                punctuateTokens(batch, builder, false);
                batch.clear(); // removes the batch from collectedTokens in one step
            }
            if (collectionTimeElapsed() || flush) {
                punctuateTokens(this.collectedTokens, builder, flush);
                this.collectedTokens.clear();
            }
        }
        if (builder.getTextLength() == 0 && flush && this.lastToken != null && this.lastPunctuationLabelIndex < 2) {
            builder.appendText(".");
        }
        return builder.toPunctuatedText();
    }

    /**
     * Punctuates a complete text in one go, independently of the streaming buffer.
     *
     * <p>Note: this still reads and updates the last-word/last-label state shared with
     * {@link #punctuate(List, boolean)}.
     *
     * @param str raw text to tokenize and punctuate
     * @return the punctuated text, closed with "." when the last predicted label is 0
     */
    public STTextPunctuator.PunctuatedText punctuate(String str) {
        // Copy: the tokenizer's list is not guaranteed to be mutable.
        List<RecasepuncTokenizer.Token> tokens = new ArrayList<>(this.tokenizer.tokenizeToTokens(str));
        tokens.add(0, newClsToken());
        STTextPunctuator.PunctuatedTextBuilder builder = new STTextPunctuator.PunctuatedTextBuilder();
        punctuateTokens(tokens, builder, true);
        return builder.toPunctuatedText();
    }

    /**
     * Runs the predictor over {@code list} in chunks of at most {@link #BATCH_SIZE} tokens.
     * Each non-special token is appended to {@code punctuatedTextBuilder` with its predicted
     * casing and punctuation.
     *
     * @param list tokens to punctuate; not modified
     * @param punctuatedTextBuilder receives the text, word count and start/end times
     * @param flush when true and the last predicted punctuation label is 0, appends a final "."
     */
    public void punctuateTokens(List<RecasepuncTokenizer.Token> list, STTextPunctuator.PunctuatedTextBuilder punctuatedTextBuilder, boolean flush) {
        for (int chunkStart = 0; chunkStart < list.size(); chunkStart += BATCH_SIZE) {
            List<RecasepuncTokenizer.Token> chunk = list.subList(chunkStart, Math.min(chunkStart + BATCH_SIZE, list.size()));
            int[] ids = new int[chunk.size()];
            for (int k = 0; k < ids.length; k++) {
                ids[k] = chunk.get(k).tokenId;
            }
            RecasepuncPredictor.PunctuationPredictions predictions = this.predictor.predict(ids);
            for (int k = 0; k < chunk.size(); k++) {
                RecasepuncTokenizer.Token token = chunk.get(k);
                if (isSpecialToken(token)) {
                    continue;
                }
                // Look ahead in the whole list so chunk boundaries do not hide a following apostrophe.
                int next = chunkStart + k + 1;
                boolean nextIsApostrophe = next < list.size() && "'".equals(list.get(next).token);
                appendWord(token, (int) predictions.punctuations[k], (int) predictions.cases[k], nextIsApostrophe, punctuatedTextBuilder);
            }
        }
        if (flush && this.lastPunctuationLabelIndex < 1) {
            punctuatedTextBuilder.appendText(".");
        }
    }

    /** Appends one word with its predicted case and punctuation and updates the streaming state. */
    private void appendWord(RecasepuncTokenizer.Token token, int punctuationLabel, int caseLabel,
                            boolean nextIsApostrophe, STTextPunctuator.PunctuatedTextBuilder builder) {
        // After a punctuation label > 1, case labels 0 and 2 are both promoted to 2 (capitalized).
        if (this.lastPunctuationLabelIndex > 1 && (caseLabel == 0 || caseLabel == 2)) {
            caseLabel = 2;
        }
        String text = token.token;
        if (caseLabel == 2 || caseLabel == 1) {
            text = STTextTimePunctuator.uppercaseWordFirstLetter(text);
        }
        builder.appendText(text);
        builder.appendText(RecasepuncPredictor.punctuationSymbols[punctuationLabel]);
        builder.incrementWordsCount(1);
        // A time of 0.0 is treated as "no timing information" (e.g. tokens from punctuate(String)).
        if (token.start != 0.0d && builder.getStartMs() == 0) {
            builder.setStartMs((long) (token.start * 1000.0d));
        }
        if (token.end != 0.0d) {
            builder.setEndMs((long) (token.end * 1000.0d));
        }
        // No space after an apostrophe or before one, so contractions stay joined.
        if (!text.equals("'") && !nextIsApostrophe) {
            builder.appendText(" ");
        }
        if (this.lastToken != null
                && this.punctuationParams.getNewLineDelayMs() > 0
                && token.start - this.lastToken.end > this.punctuationParams.getNewLineDelayMs() / 1000.0d
                && punctuationLabel > 1) {
            builder.appendText("\n\n");
        }
        this.lastPunctuationLabelIndex = punctuationLabel;
        this.lastToken = token;
    }

    /**
     * Returns whether the buffered words span at least the configured collection time. Leading
     * special markers are skipped because they are not transcribed words.
     */
    private boolean collectionTimeElapsed() {
        int size = this.collectedTokens.size();
        int first = 0;
        while (first < size && isSpecialToken(this.collectedTokens.get(first))) {
            first++;
        }
        if (first >= size) {
            return false;
        }
        double spanSeconds = this.collectedTokens.get(size - 1).end - this.collectedTokens.get(first).start;
        return spanSeconds * 1000.0d >= (double) this.punctuationParams.getCollectionTimeMs();
    }

    /** Builds the {@code [CLS]} marker token that prefixes model input. */
    private RecasepuncTokenizer.Token newClsToken() {
        return new RecasepuncTokenizer.Token(
                this.predictor.getVocabulary().getToken(RecasepuncVocab.CLS_TOKEN_ID), RecasepuncVocab.CLS_TOKEN_ID);
    }

    /** True for marker tokens that are fed to the model but never written to the output. */
    private static boolean isSpecialToken(RecasepuncTokenizer.Token token) {
        int id = token.tokenId;
        return id == RecasepuncVocab.CLS_TOKEN_ID || id == BERT_CLS_TOKEN_ID || id == BERT_SEP_TOKEN_ID;
    }

    /**
     * Manual smoke test: punctuates a UTF-8 transcript file and prints the result.
     *
     * <p>Usage: {@code RecasepuncPunctuator [modelPath [vocabPath [transcriptPath]]]}. Each missing
     * argument falls back to the developer-machine path this tool originally hard-coded.
     */
    public static void main(String[] strArr) throws Exception {
        String modelPath = strArr.length > 0 ? strArr[0]
                : "/Users/nkutuzov/Streamscape/git/NeeveBuild/sttext/src/main/resources/recasepunc/checkpoint_traced_en_00.22.pt";
        String vocabPath = strArr.length > 1 ? strArr[1]
                : "/Users/nkutuzov/Streamscape/git/NeeveBuild/sttext/src/main/resources/recasepunc/recasepunc_vocab.json";
        String transcriptPath = strArr.length > 2 ? strArr[2]
                : "/Users/nkutuzov/Streamscape/mnodes/Sysplex1/TestNode15_vosk/audiofiles/recording001.webm.transcript.txt";
        RecasepuncPredictor recasepuncPredictor = new RecasepuncPredictor(modelPath, vocabPath);
        recasepuncPredictor.initialize();
        String transcript = new String(Files.readAllBytes(new File(transcriptPath).toPath()), StandardCharsets.UTF_8);
        System.out.println(new RecasepuncPunctuator(recasepuncPredictor, PunctuationParams.DEFAULT_BERT).punctuate(transcript).getText());
    }
}
