package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpHelpers;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.class */
/**
 * {@link NlpTask.Processor} for the fill-mask task.
 *
 * <p>Validates that each input contains exactly one occurrence of the tokenizer's mask token and
 * converts the raw model output at the mask position into a {@link FillMaskResults} holding the
 * best replacement token plus, optionally, a ranked list of alternatives.
 */
public class FillMaskProcessor extends NlpTask.Processor {
    /** Number of top classes requested when the supplied config is not a {@link FillMaskConfig}. */
    private static final int DEFAULT_NUM_TOP_CLASSES = 5;

    /** Results field name used when the caller does not provide one. */
    private static final String DEFAULT_RESULTS_FIELD = "predicted_value";

    /**
     * Creates a processor backed by the given tokenizer.
     *
     * @param nlpTokenizer tokenizer that supplies the mask token and builds inference requests
     */
    public FillMaskProcessor(NlpTokenizer nlpTokenizer) {
        super(nlpTokenizer);
    }

    /**
     * Checks that the request is non-empty and that every input contains exactly one mask token.
     *
     * <p>"Empty request" and "mask token missing" problems are accumulated and reported together in
     * a single {@link ValidationException}; an input with more than one mask token is rejected
     * immediately with a bad-request exception.
     *
     * @param inputs the input strings of the inference request
     * @throws ValidationException if the request is empty or any input lacks the mask token
     */
    @Override
    public void validateInputs(List<String> inputs) {
        ValidationException validationException = new ValidationException();
        if (inputs.isEmpty()) {
            validationException.addValidationError("input request is empty");
        }
        String maskToken = this.tokenizer.getMaskToken();
        for (String input : inputs) {
            int firstMask = input.indexOf(maskToken);
            if (firstMask < 0) {
                validationException.addValidationError("no " + maskToken + " token could be found in the input");
                // Nothing further to check: without a first occurrence there cannot be a second.
                continue;
            }
            // Search strictly after the first occurrence to detect a duplicate mask token.
            if (input.indexOf(maskToken, firstMask + maskToken.length()) > 0) {
                throw ExceptionsHelper.badRequestException("only one {} token should exist in the input", maskToken);
            }
        }
        if (!validationException.validationErrors().isEmpty()) {
            throw validationException;
        }
    }

    /**
     * Returns the tokenizer's request builder; the supplied config is not consulted.
     *
     * @param nlpConfig the task configuration (unused)
     * @return the request builder provided by the tokenizer
     */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) {
        return this.tokenizer.requestBuilder();
    }

    /**
     * Builds the result processor for this task.
     *
     * <p>When {@code nlpConfig} is a {@link FillMaskConfig} its number of top classes and results
     * field are used; otherwise the processor falls back to {@value #DEFAULT_NUM_TOP_CLASSES} top
     * classes and the {@value #DEFAULT_RESULTS_FIELD} results field.
     *
     * @param nlpConfig the task configuration
     * @return a result processor delegating to
     *     {@link #processResult(TokenizationResult, PyTorchInferenceResult, NlpTokenizer, int, String, boolean)}
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) {
        if (nlpConfig instanceof FillMaskConfig) {
            FillMaskConfig fillMaskConfig = (FillMaskConfig) nlpConfig;
            // Config values are read when the processor runs, not when it is created.
            return (tokenization, pyTorchResult, chunkResults) -> processResult(
                tokenization,
                pyTorchResult,
                this.tokenizer,
                fillMaskConfig.getNumTopClasses(),
                fillMaskConfig.getResultsField(),
                chunkResults
            );
        }
        return (tokenization, pyTorchResult, chunkResults) -> processResult(
            tokenization,
            pyTorchResult,
            this.tokenizer,
            DEFAULT_NUM_TOP_CLASSES,
            DEFAULT_RESULTS_FIELD,
            chunkResults
        );
    }

    /**
     * Converts the model output for the first tokenized input into a {@link FillMaskResults}.
     *
     * <p>The scores at the mask token position are turned into probabilities via soft-max and the
     * highest-scoring vocabulary entries are selected. The best entry is always computed, even when
     * {@code numTopClasses} is {@code 0}, because it is needed for the predicted value; in that
     * case the returned list of top classes is simply left empty.
     *
     * @param tokenization the tokenization of the request; only the first entry is processed
     * @param pyTorchResult the raw inference result
     * @param tokenizer tokenizer supplying the mask token and its id
     * @param numTopClasses how many top classes to report; {@code -1} requests all of them and
     *     {@code 0} requests none
     * @param resultsField name of the results field, or {@code null} to use
     *     {@value #DEFAULT_RESULTS_FIELD}
     * @param chunkResults whether chunked results were requested; not supported by this task
     * @return the fill-mask result for the first input
     * @throws ElasticsearchStatusException if the tokenization is empty, chunking was requested,
     *     the mask token id is unknown to the tokenizer, or the mask token id does not appear in
     *     the tokenization
     */
    static InferenceResults processResult(
        TokenizationResult tokenization,
        PyTorchInferenceResult pyTorchResult,
        NlpTokenizer tokenizer,
        int numTopClasses,
        String resultsField,
        boolean chunkResults
    ) {
        if (tokenization.isEmpty()) {
            throw new ElasticsearchStatusException("tokenization is empty", RestStatus.INTERNAL_SERVER_ERROR);
        }
        if (chunkResults) {
            // Report the task this processor actually implements (previously mis-reported as NER).
            throw chunkingNotSupportedException(TaskType.FILL_MASK);
        }
        OptionalInt maskTokenId = tokenizer.getMaskTokenId();
        if (maskTokenId.isEmpty()) {
            throw ExceptionsHelper.conflictStatusException(
                "The token id for the mask token {} is not known in the tokenizer. Check the vocabulary contains the mask token",
                tokenizer.getMaskToken()
            );
        }
        int maskId = maskTokenId.getAsInt();
        OptionalInt maskPosition = tokenization.getTokenization(0).getTokenIndex(maskId);
        if (maskPosition.isEmpty()) {
            throw new ElasticsearchStatusException(
                "mask token id [{}] not found in the tokenization",
                RestStatus.INTERNAL_SERVER_ERROR,
                maskId
            );
        }

        // At least one candidate is always needed to populate the predicted value.
        int candidateCount = numTopClasses == -1 ? Integer.MAX_VALUE : Math.max(numTopClasses, 1);
        NlpHelpers.ScoreAndIndex[] topScores = NlpHelpers.topK(
            candidateCount,
            NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskPosition.getAsInt()])
        );

        List<TopClassEntry> topClasses = new ArrayList<>(topScores.length);
        if (numTopClasses != 0) {
            for (NlpHelpers.ScoreAndIndex scoreAndIndex : topScores) {
                String candidateToken = tokenization.decode(tokenization.getFromVocab(scoreAndIndex.index));
                topClasses.add(new TopClassEntry(candidateToken, scoreAndIndex.score, scoreAndIndex.score));
            }
        }

        NlpHelpers.ScoreAndIndex best = topScores[0];
        String predictedValue = tokenization.decode(tokenization.getFromVocab(best.index));
        String predictedSequence = tokenization.getTokenization(0).input().get(0).replace(tokenizer.getMaskToken(), predictedValue);
        return new FillMaskResults(
            predictedValue,
            predictedSequence,
            topClasses,
            Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
            Double.valueOf(best.score),
            tokenization.anyTruncated()
        );
    }
}
