package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.OptionalInt;
import java.util.stream.Collectors;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertJapaneseTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.XLMRobertaTokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.aggs.frequentitemsets.FrequentItemSetsAggregationBuilder;
import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.class */
public abstract class NlpTokenizer implements Releasable {
    /**
     * Sentinel span value ({@code -2}). The single-sequence {@code tokenize} method replaces a span equal to
     * this value with the result of {@link #defaultSpanForChunking(int)} before splitting the input into windows.
     */
    public static final int CALC_DEFAULT_SPAN_VALUE = -2;

    /**
     * Switch lookup table for {@link Tokenization.Truncate}, as recovered by the decompiler.
     *
     * <p>The array is indexed by {@code Truncate.ordinal()} and holds a dense case label:
     * {@code FIRST -> 1}, {@code SECOND -> 2}, {@code NONE -> 3}. Any ordinal that is not assigned
     * below keeps the array default of {@code 0}, which matches no case label.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer$1, reason: invalid class name */
    /* loaded from: input_file:org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        // One slot per Truncate constant known at class-initialisation time.
        static final /* synthetic */ int[] $SwitchMap$org$elasticsearch$xpack$core$ml$inference$trainedmodel$Tokenization$Truncate = new int[Tokenization.Truncate.values().length];

        static {
            // Each assignment is guarded separately so that one missing constant does not prevent the others
            // from being registered.
            try {
                $SwitchMap$org$elasticsearch$xpack$core$ml$inference$trainedmodel$Tokenization$Truncate[Tokenization.Truncate.FIRST.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // Deliberately ignored: the slot simply stays 0 if the constant cannot be resolved.
            }
            try {
                $SwitchMap$org$elasticsearch$xpack$core$ml$inference$trainedmodel$Tokenization$Truncate[Tokenization.Truncate.SECOND.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // Deliberately ignored: the slot simply stays 0 if the constant cannot be resolved.
            }
            try {
                $SwitchMap$org$elasticsearch$xpack$core$ml$inference$trainedmodel$Tokenization$Truncate[Tokenization.Truncate.NONE.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // Deliberately ignored: the slot simply stays 0 if the constant cannot be resolved.
            }
        }
    }

    /* loaded from: input_file:org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer$InnerTokenization.class */
    public static final class InnerTokenization extends Record {
        private final List<? extends DelimitedToken.Encoded> tokens;
        private final List<Integer> tokenPositionMap;

        public InnerTokenization(List<? extends DelimitedToken.Encoded> list, List<Integer> list2) {
            this.tokens = list;
            this.tokenPositionMap = list2;
        }

        @Override // java.lang.Record
        public final String toString() {
            return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, InnerTokenization.class), InnerTokenization.class, "tokens;tokenPositionMap", "FIELD:Lorg/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer$InnerTokenization;->tokens:Ljava/util/List;", "FIELD:Lorg/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer$InnerTokenization;->tokenPositionMap:Ljava/util/List;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final int hashCode() {
            return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, InnerTokenization.class), InnerTokenization.class, "tokens;tokenPositionMap", "FIELD:Lorg/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer$InnerTokenization;->tokens:Ljava/util/List;", "FIELD:Lorg/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer$InnerTokenization;->tokenPositionMap:Ljava/util/List;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final boolean equals(Object obj) {
            return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, InnerTokenization.class, Object.class), InnerTokenization.class, "tokens;tokenPositionMap", "FIELD:Lorg/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer$InnerTokenization;->tokens:Ljava/util/List;", "FIELD:Lorg/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer$InnerTokenization;->tokenPositionMap:Ljava/util/List;").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
        }

        public List<? extends DelimitedToken.Encoded> tokens() {
            return this.tokens;
        }

        public List<Integer> tokenPositionMap() {
            return this.tokenPositionMap;
        }
    }

    /**
     * @return the token id passed as the first argument to {@link #createTokensBuilder(int, int, boolean)}
     *         by the {@code tokenize} methods (presumably the id of the classification token; confirm in
     *         the implementations)
     */
    abstract int clsTokenId();

    /**
     * @return the token id passed as the second argument to {@link #createTokensBuilder(int, int, boolean)}
     *         by the {@code tokenize} methods (presumably the id of the separator token; confirm in the
     *         implementations)
     */
    abstract int sepTokenId();

    /**
     * @return the maximum number of tokens allowed in one tokenized input. Used as the default window size
     *         for single sequences and as the hard limit for sequence pairs.
     */
    abstract int maxSequenceLength();

    /**
     * @return whether extra (special) tokens are accounted for when sizing the input. Sequence-pair
     *         tokenization throws when this is {@code false}.
     */
    abstract boolean isWithSpecialTokens();

    /**
     * @return the number of extra tokens counted against the window for a single sequence when
     *         {@link #isWithSpecialTokens()} is {@code true}
     */
    abstract int numExtraTokensForSingleSequence();

    /**
     * @return the number of extra tokens counted against {@link #maxSequenceLength()} for a sequence pair
     */
    abstract int getNumExtraTokensForSeqPair();

    /**
     * Supplies the span to use when the caller requested the computed default (span {@code -2}).
     *
     * @param i the window size in effect for the current tokenization
     * @return the span to use between consecutive windows
     */
    abstract int defaultSpanForChunking(int i);

    /**
     * Combines individual tokenizations into a single result.
     *
     * @param list the tokenizations to combine
     * @return the combined result; the exact layout is defined by the implementation
     */
    public abstract TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> list);

    /**
     * Tokenizes a single sequence, truncating it or splitting it into overlapping windows when it does not
     * fit into the window size.
     *
     * <p>If the input fits, or if {@code i} is {@code -1}, exactly one element is returned. Otherwise the
     * tokens are split into windows of at most the window size (minus the extra tokens when
     * {@link #isWithSpecialTokens()} is {@code true}); each new window starts {@code span} tokens before
     * the end of the previous one. Window boundaries are moved backwards while the boundary token has
     * the same position-map entry as its predecessor, so that such tokens stay together where possible.
     *
     * @param str      the text to tokenize
     * @param truncate what to do when the input is too long; {@code FIRST} and {@code SECOND} both cut
     *                 the token list down to the window, {@code NONE} rejects the input unless a span is set
     * @param i        the span (overlap) between consecutive windows; {@code -1} disables windowing and
     *                 {@link #CALC_DEFAULT_SPAN_VALUE} asks for {@link #defaultSpanForChunking(int)}
     * @param i2       passed through unchanged as the last argument of the tokens builder's {@code build}
     *                 call (presumably the sequence id; confirm against the builder)
     * @param num      the window size, or {@code null} to use {@link #maxSequenceLength()}
     * @return one entry per window, in input order
     * @throws IllegalStateException if the span is so large that the next window could not start after
     *                               the previous one
     * @throws RuntimeException      the bad-request exception from {@code ExceptionsHelper} when the input
     *                               is too long, {@code truncate} is {@code NONE} and {@code i} is {@code -1}
     */
    public final List<TokenizationResult.Tokens> tokenize(String str, Tokenization.Truncate truncate, int i, int i2, Integer num) {
        final int windowSize = num == null ? maxSequenceLength() : num;
        int span = i;

        final InnerTokenization innerResult = innerTokenize(str);
        List<? extends DelimitedToken.Encoded> tokens = innerResult.tokens();
        List<Integer> positions = innerResult.tokenPositionMap();

        final boolean withSpecialTokens = isWithSpecialTokens();
        final int numTokens = withSpecialTokens ? tokens.size() + numExtraTokensForSingleSequence() : tokens.size();
        // Number of real tokens that fit into one window once the extra tokens are accounted for.
        final int capacity = withSpecialTokens ? windowSize - numExtraTokensForSingleSequence() : windowSize;
        boolean truncated = false;

        if (numTokens > windowSize) {
            // Switch on the enum itself rather than on an ordinal lookup table.
            switch (truncate) {
                case FIRST:
                case SECOND:
                    truncated = true;
                    tokens = tokens.subList(0, capacity);
                    positions = positions.subList(0, capacity);
                    break;
                case NONE:
                    if (span == -1) {
                        throw ExceptionsHelper.badRequestException(
                            "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
                            numTokens,
                            windowSize
                        );
                    }
                    break;
                default:
                    break;
            }
        }

        if (numTokens <= windowSize || span == -1) {
            return List.of(
                createTokensBuilder(clsTokenId(), sepTokenId(), withSpecialTokens).addSequence(
                    tokens.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()),
                    positions
                ).build(str, truncated, innerResult.tokens(), -1, i2)
            );
        }

        // NOTE(review): when the input was truncated above AND a span is set, execution reaches the loop
        // below with the already-truncated list and reports the single window as not truncated. Kept as is.
        if (span == CALC_DEFAULT_SPAN_VALUE) {
            span = defaultSpanForChunking(windowSize);
        }

        final List<TokenizationResult.Tokens> windows = new ArrayList<>();
        int windowEnd = 0;
        int windowStart = 0;
        int previousSpan = -1;
        while (windowEnd < tokens.size()) {
            windowEnd = Math.min(windowStart + capacity, tokens.size());
            // Do not end a window between two tokens that share a position-map entry.
            if (windowEnd != tokens.size()) {
                while (windowEnd > windowStart + 1 && Objects.equals(positions.get(windowEnd), positions.get(windowEnd - 1))) {
                    windowEnd--;
                }
            }
            final List<? extends DelimitedToken.Encoded> windowTokens = tokens.subList(windowStart, windowEnd);
            windows.add(
                createTokensBuilder(clsTokenId(), sepTokenId(), withSpecialTokens).addSequence(
                    windowTokens.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()),
                    positions.subList(windowStart, windowEnd)
                ).build(str, false, windowTokens, previousSpan, i2)
            );

            previousSpan = span;
            final int previousStart = windowStart;
            windowStart = windowEnd - span;
            // If more tokens remain but the next window would not start after the previous one, the loop
            // could only spin forever or index out of bounds; fail with a clear message instead.
            if (windowEnd < tokens.size() && windowStart <= previousStart) {
                throw new IllegalStateException(
                    "Tokenization cannot be satisfied with the current span setting. Consider decreasing the span setting"
                );
            }
            // Back the start up so the next window begins on the first token of a position-map run.
            if (windowStart < tokens.size()) {
                while (windowStart > previousStart + 1 && Objects.equals(positions.get(windowStart), positions.get(windowStart - 1))) {
                    windowStart--;
                    previousSpan++;
                }
            }
        }
        return windows;
    }

    /**
     * Tokenizes a pair of sequences into a single result, tokenizing the first sequence here and
     * delegating the rest to the overload that accepts a pre-tokenized first sequence.
     *
     * @param str      the first sequence
     * @param str2     the second sequence
     * @param truncate how to shorten the pair when it exceeds {@link #maxSequenceLength()}
     * @param i        passed through unchanged to the delegate
     * @return the tokenized pair
     */
    public TokenizationResult.Tokens tokenize(String str, String str2, Tokenization.Truncate truncate, int i) {
        final InnerTokenization firstSequenceTokens = innerTokenize(str);
        return tokenize(str, firstSequenceTokens, str2, truncate, i);
    }

    /**
     * Tokenizes a pair of sequences into a single result, reusing an existing tokenization of the first
     * sequence.
     *
     * <p>When the pair plus the extra tokens exceeds {@link #maxSequenceLength()}, {@code FIRST} shortens
     * the first sequence, {@code SECOND} shortens the second, and {@code NONE} rejects the input. The
     * sequence that is not being shortened must fit on its own, otherwise the input is rejected.
     *
     * @param str               the first sequence text
     * @param innerTokenization the tokenization of {@code str}
     * @param str2              the second sequence text; tokenized by this method
     * @param truncate          the truncation strategy
     * @param i                 passed through unchanged as the last argument of the tokens builder's
     *                          {@code build} call (presumably the sequence id; confirm against the builder)
     * @return the tokenized pair
     * @throws IllegalArgumentException if {@link #isWithSpecialTokens()} is {@code false}; checked before
     *                                  any tokenization work is done
     * @throws RuntimeException         the bad-request exception from {@code ExceptionsHelper} when the
     *                                  pair cannot be made to fit
     */
    public TokenizationResult.Tokens tokenize(String str, InnerTokenization innerTokenization, String str2, Tokenization.Truncate truncate, int i) {
        // Fail fast, as the windowed pair overload does, instead of tokenizing the second sequence first.
        if (isWithSpecialTokens() == false) {
            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
        }
        List<? extends DelimitedToken.Encoded> firstTokens = innerTokenization.tokens();
        List<Integer> firstPositions = innerTokenization.tokenPositionMap();
        final InnerTokenization secondResult = innerTokenize(str2);
        List<? extends DelimitedToken.Encoded> secondTokens = secondResult.tokens();
        List<Integer> secondPositions = secondResult.tokenPositionMap();

        final int extraTokens = getNumExtraTokensForSeqPair();
        final int numTokens = firstTokens.size() + secondTokens.size() + extraTokens;
        boolean truncated = false;

        if (numTokens > maxSequenceLength()) {
            // Tokens left for the two sequences once the extra tokens are accounted for.
            final int available = maxSequenceLength() - extraTokens;
            // Switch on the enum itself rather than on an ordinal lookup table.
            switch (truncate) {
                case FIRST: {
                    truncated = true;
                    if (secondTokens.size() > available) {
                        throw ExceptionsHelper.badRequestException(
                            "Attempting truncation [{}] but input is too large for the second sequence. The tokenized input length [{}] exceeds the maximum sequence length [{}], when taking special tokens into account",
                            truncate.toString(),
                            secondTokens.size(),
                            available
                        );
                    }
                    final int keep = available - secondTokens.size();
                    firstTokens = firstTokens.subList(0, keep);
                    firstPositions = firstPositions.subList(0, keep);
                    break;
                }
                case SECOND: {
                    truncated = true;
                    if (firstTokens.size() > available) {
                        throw ExceptionsHelper.badRequestException(
                            "Attempting truncation [{}] but input is too large for the first sequence. The tokenized input length [{}] exceeds the maximum sequence length [{}], when taking special tokens into account",
                            truncate.toString(),
                            firstTokens.size(),
                            available
                        );
                    }
                    final int keep = available - firstTokens.size();
                    secondTokens = secondTokens.subList(0, keep);
                    secondPositions = secondPositions.subList(0, keep);
                    break;
                }
                case NONE:
                    throw ExceptionsHelper.badRequestException(
                        "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
                        numTokens,
                        maxSequenceLength()
                    );
                default:
                    break;
            }
        }

        // The builder receives the (possibly shortened) encodings, while build() is given the full,
        // untruncated token lists of both sequences.
        return createTokensBuilder(clsTokenId(), sepTokenId(), isWithSpecialTokens()).addSequencePair(
            firstTokens.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()),
            firstPositions,
            secondTokens.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()),
            secondPositions
        ).build(List.of(str, str2), truncated, List.of(innerTokenization.tokens(), secondResult.tokens()), -1, i);
    }

    /**
     * Tokenizes a pair of sequences, splitting the second sequence into overlapping windows when the pair
     * does not fit into {@link #maxSequenceLength()} and a non-negative span is given.
     *
     * <p>With a negative span the pair is truncated according to {@code truncate} (or rejected for
     * {@code NONE}) and a single element is returned. A pair that fits, including one that fits exactly,
     * is also returned as a single element. Otherwise the first sequence is repeated in full in every
     * window and the second sequence is windowed; each new window starts {@code i} tokens before the end of
     * the previous one, with boundaries moved backwards so that tokens sharing a position-map entry stay
     * together where possible.
     *
     * @param str      the first sequence
     * @param str2     the second sequence
     * @param truncate the truncation strategy, only consulted when {@code i} is negative
     * @param i        the span (overlap) between consecutive windows of the second sequence; a negative
     *                 value disables windowing
     * @param i2       passed through unchanged as the last argument of the tokens builder's {@code build}
     *                 call (presumably the sequence id; confirm against the builder)
     * @return one entry per window, in input order
     * @throws IllegalArgumentException if special tokens are disabled, if the first sequence leaves no
     *                                  room for the second, or if the span exceeds the room left
     * @throws IllegalStateException    if a window cannot start after the previous one
     * @throws RuntimeException         the bad-request exception from {@code ExceptionsHelper} when
     *                                  truncation cannot make the pair fit
     */
    public List<TokenizationResult.Tokens> tokenize(String str, String str2, Tokenization.Truncate truncate, int i, int i2) {
        if (isWithSpecialTokens() == false) {
            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
        }
        final InnerTokenization firstResult = innerTokenize(str);
        List<? extends DelimitedToken.Encoded> firstTokens = firstResult.tokens();
        List<Integer> firstPositions = firstResult.tokenPositionMap();
        final InnerTokenization secondResult = innerTokenize(str2);
        List<? extends DelimitedToken.Encoded> secondTokens = secondResult.tokens();
        List<Integer> secondPositions = secondResult.tokenPositionMap();

        final int extraTokens = getNumExtraTokensForSeqPair();
        final int numTokens = firstTokens.size() + secondTokens.size() + extraTokens;
        boolean truncated = false;

        if (numTokens > maxSequenceLength() && i < 0) {
            // Tokens left for the two sequences once the extra tokens are accounted for.
            final int available = maxSequenceLength() - extraTokens;
            // Switch on the enum itself rather than on an ordinal lookup table.
            switch (truncate) {
                case FIRST: {
                    truncated = true;
                    if (secondTokens.size() > available) {
                        throw ExceptionsHelper.badRequestException(
                            "Attempting truncation [{}] but input is too large for the second sequence. The tokenized input length [{}] exceeds the maximum sequence length [{}], when taking special tokens into account",
                            truncate.toString(),
                            secondTokens.size(),
                            available
                        );
                    }
                    final int keep = available - secondTokens.size();
                    firstTokens = firstTokens.subList(0, keep);
                    firstPositions = firstPositions.subList(0, keep);
                    break;
                }
                case SECOND: {
                    truncated = true;
                    if (firstTokens.size() > available) {
                        throw ExceptionsHelper.badRequestException(
                            "Attempting truncation [{}] but input is too large for the first sequence. The tokenized input length [{}] exceeds the maximum sequence length [{}], when taking special tokens into account",
                            truncate.toString(),
                            firstTokens.size(),
                            available
                        );
                    }
                    final int keep = available - firstTokens.size();
                    secondTokens = secondTokens.subList(0, keep);
                    secondPositions = secondPositions.subList(0, keep);
                    break;
                }
                case NONE:
                    throw ExceptionsHelper.badRequestException(
                        "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
                        numTokens,
                        maxSequenceLength()
                    );
                default:
                    break;
            }
        }

        // "<=" rather than "<": a pair that fits exactly needs no windowing. Sending it into the windowing
        // path could reject it spuriously (empty second sequence, or span equal to the remaining room).
        if (truncated || numTokens <= maxSequenceLength()) {
            return List.of(
                createTokensBuilder(clsTokenId(), sepTokenId(), isWithSpecialTokens()).addSequencePair(
                    firstTokens.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()),
                    firstPositions,
                    secondTokens.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()),
                    secondPositions
                ).build(List.of(str, str2), truncated, List.of(firstResult.tokens(), secondResult.tokens()), -1, i2)
            );
        }

        final List<TokenizationResult.Tokens> windows = new ArrayList<>();
        // The first sequence is repeated in full in every window, so encode it once.
        final List<Integer> firstEncodings = firstTokens.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList());
        // Room left for second-sequence tokens in each window.
        final int secondWindowSize = maxSequenceLength() - extraTokens - firstTokens.size();
        if (secondWindowSize <= 0) {
            throw new IllegalArgumentException(
                Strings.format(
                    "Unable to do sequence pair tokenization: the first sequence [%d tokens] is longer than the max sequence length [%d tokens]",
                    firstTokens.size() + extraTokens,
                    maxSequenceLength()
                )
            );
        }
        if (i > secondWindowSize) {
            throw new IllegalArgumentException(
                Strings.format(
                    "Unable to do sequence pair tokenization: the combined first sequence, span length and delimiting tokens [%d + %d + %d = %d tokens] is longer than the max sequence length [%d tokens]. Reduce the size of the [span] window.",
                    firstTokens.size(),
                    i,
                    extraTokens,
                    firstTokens.size() + i + extraTokens,
                    maxSequenceLength()
                )
            );
        }

        int windowEnd = 0;
        int windowStart = 0;
        int previousSpan = -1;
        while (windowEnd < secondTokens.size()) {
            windowEnd = Math.min(windowStart + secondWindowSize, secondTokens.size());
            // Do not end a window between two tokens that share a position-map entry.
            if (windowEnd != secondTokens.size()) {
                while (windowEnd > windowStart + 1
                    && Objects.equals(secondPositions.get(windowEnd), secondPositions.get(windowEnd - 1))) {
                    windowEnd--;
                }
            }
            final List<? extends DelimitedToken.Encoded> windowTokens = secondTokens.subList(windowStart, windowEnd);
            windows.add(
                createTokensBuilder(clsTokenId(), sepTokenId(), isWithSpecialTokens()).addSequencePair(
                    firstEncodings,
                    firstPositions,
                    windowTokens.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList()),
                    secondPositions.subList(windowStart, windowEnd)
                ).build(List.of(str, str2), false, List.of(firstTokens, windowTokens), previousSpan, i2)
            );

            previousSpan = i;
            final int previousStart = windowStart;
            windowStart = windowEnd - i;
            if (windowStart <= previousStart) {
                throw new IllegalStateException(
                    "Tokenization cannot be satisfied with the current span setting. Consider decreasing the span setting"
                );
            }
            // Back the start up so the next window begins on the first token of a position-map run.
            if (windowStart < secondTokens.size()) {
                while (windowStart > previousStart + 1
                    && Objects.equals(secondPositions.get(windowStart), secondPositions.get(windowStart - 1))) {
                    windowStart--;
                    previousSpan++;
                }
            }
        }
        return windows;
    }

    /**
     * @return the {@link NlpTask.RequestBuilder} supplied by the concrete tokenizer; its behaviour is
     *         defined by the implementation
     */
    public abstract NlpTask.RequestBuilder requestBuilder();

    /**
     * @return the id of the padding token, or an empty {@link OptionalInt} (presumably when the tokenizer
     *         has no padding token; confirm in the implementations)
     */
    public abstract OptionalInt getPadTokenId();

    /**
     * @return the padding token as a string, as defined by the implementation
     */
    public abstract String getPadToken();

    /**
     * @return the id of the mask token, or an empty {@link OptionalInt} (presumably when the tokenizer has
     *         no mask token; confirm in the implementations)
     */
    public abstract OptionalInt getMaskTokenId();

    /**
     * @return the mask token as a string, as defined by the implementation
     */
    public abstract String getMaskToken();

    /**
     * @return the vocabulary entries as strings, as defined by the implementation
     */
    public abstract List<String> getVocabulary();

    /**
     * Returns the span configured for this tokenizer. This base implementation always answers
     * {@code -1}, the value for which the single-sequence {@code tokenize} method never splits the
     * input into windows; subclasses may override it.
     *
     * @return {@code -1}
     */
    public int getSpan() {
        final int spanDisabled = -1;
        return spanDisabled;
    }

    /**
     * Creates the builder used to assemble one tokenized window. The {@code tokenize} methods call it
     * with {@link #clsTokenId()}, {@link #sepTokenId()} and {@link #isWithSpecialTokens()}.
     *
     * @param i  the first token id (supplied from {@link #clsTokenId()})
     * @param i2 the second token id (supplied from {@link #sepTokenId()})
     * @param z  whether special tokens are in use (supplied from {@link #isWithSpecialTokens()})
     * @return a new builder
     */
    abstract TokenizationResult.TokensBuilder createTokensBuilder(int i, int i2, boolean z);

    /**
     * Performs the tokenizer-specific tokenization of one string, without truncation or windowing.
     *
     * @param str the text to tokenize
     * @return the encoded tokens together with their position map
     */
    public abstract InnerTokenization innerTokenize(String str);

    /**
     * Creates the tokenizer implementation that matches the given tokenization configuration.
     *
     * <p>The configuration types are tested in a fixed order: BERT, BERT Japanese, MPNet, RoBERTa and
     * finally XLM-RoBERTa.
     *
     * @param vocabulary   the vocabulary to build the tokenizer from; must not be {@code null}
     * @param tokenization the tokenization configuration; must not be {@code null}
     * @return a tokenizer for the configuration
     * @throws IOException              declared for the underlying builders
     * @throws IllegalArgumentException if the configuration type is not one of the supported ones
     */
    public static NlpTokenizer build(Vocabulary vocabulary, Tokenization tokenization) throws IOException {
        ExceptionsHelper.requireNonNull(tokenization, NlpConfig.TOKENIZATION);
        ExceptionsHelper.requireNonNull(vocabulary, NlpConfig.VOCABULARY);

        if (tokenization instanceof BertTokenization) {
            return BertTokenizer.builder(vocabulary.get(), tokenization).build();
        } else if (tokenization instanceof BertJapaneseTokenization) {
            return BertJapaneseTokenizer.builder(vocabulary.get(), tokenization).build();
        } else if (tokenization instanceof MPNetTokenization) {
            return MPNetTokenizer.mpBuilder(vocabulary.get(), tokenization).build();
        } else if (tokenization instanceof RobertaTokenization roberta) {
            return RobertaTokenizer.builder(vocabulary.get(), vocabulary.merges(), roberta).build();
        } else if (tokenization instanceof XLMRobertaTokenization xlmRoberta) {
            return XLMRobertaTokenizer.builder(vocabulary.get(), vocabulary.scores(), xlmRoberta).build();
        }
        // None of the supported configuration types matched.
        throw new IllegalArgumentException("unknown tokenization type [" + tokenization.getName() + "]");
    }
}
