package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.class */
/**
 * NLP task processor that turns the raw output of a text expansion model into
 * {@link TextExpansionResults} (single result) or {@link ChunkedTextExpansionResults}
 * (one result per chunk).
 *
 * <p>Each strictly positive entry of the model's output vector becomes a weighted token whose
 * text is looked up by index, first in a map of sanitized replacement tokens and then in the
 * tokenization result's vocabulary.
 */
public class TextExpansionProcessor extends NlpTask.Processor {
    /** Results field name used when the configuration does not supply one. */
    private static final String DEFAULT_RESULTS_FIELD = "predicted_value";

    private final NlpTask.RequestBuilder requestBuilder;

    /** Vocabulary index -> sanitized token, for every vocabulary entry containing a {@code '.'}. */
    private final Map<Integer, String> replacementVocab;

    /**
     * Creates a processor backed by the given tokenizer.
     *
     * @param nlpTokenizer tokenizer supplying the request builder and the vocabulary
     */
    public TextExpansionProcessor(NlpTokenizer nlpTokenizer) {
        super(nlpTokenizer);
        this.requestBuilder = nlpTokenizer.requestBuilder();
        this.replacementVocab = buildSanitizedVocabMap(nlpTokenizer.getVocabulary());
    }

    /**
     * Builds the replacement map for vocabulary entries that contain a {@code '.'}.
     *
     * <p>Every {@code '.'} in such an entry is replaced with {@code "__"}; entries without a dot
     * are not added to the map.
     * NOTE(review): presumably dots are not acceptable in the emitted token names - confirm.
     *
     * @param vocabulary the tokenizer vocabulary, indexed by token id
     * @return a mutable map from vocabulary index to the sanitized token
     */
    static Map<Integer, String> buildSanitizedVocabMap(List<String> vocabulary) {
        Map<Integer, String> sanitized = new HashMap<>();
        for (int i = 0; i < vocabulary.size(); i++) {
            String token = vocabulary.get(i);
            if (token.contains(".")) {
                sanitized.put(i, token.replace(".", "__"));
            }
        }
        return sanitized;
    }

    /** Performs no validation; every input list is accepted. */
    @Override // org.elasticsearch.xpack.ml.inference.nlp.NlpTask.Processor
    public void validateInputs(List<String> inputs) {
    }

    /**
     * Returns the request builder obtained from the tokenizer at construction time.
     *
     * @param nlpConfig unused
     */
    @Override // org.elasticsearch.xpack.ml.inference.nlp.NlpTask.Processor
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) {
        return this.requestBuilder;
    }

    /**
     * Returns a result processor that delegates to
     * {@link #processResult(TokenizationResult, PyTorchInferenceResult, Map, String, boolean)}
     * using this processor's replacement vocabulary and the results field of {@code nlpConfig}.
     */
    @Override // org.elasticsearch.xpack.ml.inference.nlp.NlpTask.Processor
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) {
        return (tokenizationResult, pyTorchInferenceResult, chunkResults) -> processResult(
            tokenizationResult,
            pyTorchInferenceResult,
            this.replacementVocab,
            nlpConfig.getResultsField(),
            chunkResults
        );
    }

    /**
     * Converts a raw inference result into weighted-token results.
     *
     * <p>When {@code chunkResults} is {@code false}, only the vector at {@code [0][0]} of the
     * inference result is converted and a {@link TextExpansionResults} is returned. Otherwise one
     * chunk is produced for every vector in {@code getInferenceResult()[0]} and a
     * {@link ChunkedTextExpansionResults} is returned. In both cases the tokens are sorted by
     * descending weight.
     *
     * @param tokenizationResult     tokenization of the request the inference result belongs to
     * @param pyTorchInferenceResult raw model output
     * @param replacementVocab       vocabulary index to sanitized token overrides
     * @param resultsField           name of the results field; {@code null} selects the default
     * @param chunkResults           whether to return one result per chunk
     * @return the converted results
     */
    static InferenceResults processResult(
        TokenizationResult tokenizationResult,
        PyTorchInferenceResult pyTorchInferenceResult,
        Map<Integer, String> replacementVocab,
        String resultsField,
        boolean chunkResults
    ) {
        String field = Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD);

        if (!chunkResults) {
            List<TextExpansionResults.WeightedToken> weightedTokens = sparseVectorToTokenWeights(
                pyTorchInferenceResult.getInferenceResult()[0][0],
                tokenizationResult,
                replacementVocab
            );
            sortByDescendingWeight(weightedTokens);
            return new TextExpansionResults(field, weightedTokens, tokenizationResult.anyTruncated());
        }

        List<ChunkedTextExpansionResults.ChunkedResult> chunks = new ArrayList<>();
        for (int i = 0; i < pyTorchInferenceResult.getInferenceResult()[0].length; i++) {
            // The chunk text is the slice of the i-th tokenization's first input that runs from the
            // start of its first token to the end of its last token (first token sequence only).
            // These lookups throw if the tokenization has no input or no tokens.
            String chunkInput = tokenizationResult.getTokenization(i).input().get(0);
            int startOffset = tokenizationResult.getTokenization(i).tokens().get(0).get(0).startOffset();
            int lastTokenIndex = tokenizationResult.getTokenization(i).tokens().get(0).size() - 1;
            int endOffset = tokenizationResult.getTokenization(i).tokens().get(0).get(lastTokenIndex).endOffset();
            String matchedText = chunkInput.substring(startOffset, endOffset);

            List<TextExpansionResults.WeightedToken> weightedTokens = sparseVectorToTokenWeights(
                pyTorchInferenceResult.getInferenceResult()[0][i],
                tokenizationResult,
                replacementVocab
            );
            sortByDescendingWeight(weightedTokens);
            chunks.add(new ChunkedTextExpansionResults.ChunkedResult(matchedText, weightedTokens));
        }
        return new ChunkedTextExpansionResults(field, chunks, tokenizationResult.anyTruncated());
    }

    /** Sorts the tokens in place so that the highest weight comes first. */
    private static void sortByDescendingWeight(List<TextExpansionResults.WeightedToken> weightedTokens) {
        weightedTokens.sort((left, right) -> Float.compare(right.weight(), left.weight()));
    }

    /**
     * Converts a model output vector into weighted tokens.
     *
     * <p>Only entries strictly greater than zero are kept; the entry's index is used as the token
     * id and its value, narrowed to {@code float}, as the weight.
     *
     * @param vector             model output, indexed by token id
     * @param tokenizationResult source of the vocabulary lookup
     * @param replacementVocab   vocabulary index to sanitized token overrides
     * @return a mutable list of weighted tokens in ascending index order
     */
    static List<TextExpansionResults.WeightedToken> sparseVectorToTokenWeights(
        double[] vector,
        TokenizationResult tokenizationResult,
        Map<Integer, String> replacementVocab
    ) {
        List<TextExpansionResults.WeightedToken> weightedTokens = new ArrayList<>();
        for (int i = 0; i < vector.length; i++) {
            if (vector[i] > 0.0d) {
                String token = tokenForId(i, tokenizationResult, replacementVocab);
                weightedTokens.add(new TextExpansionResults.WeightedToken(token, (float) vector[i]));
            }
        }
        return weightedTokens;
    }

    /**
     * Resolves the token text for a token id.
     *
     * @param id                 token id (vocabulary index)
     * @param tokenizationResult fallback vocabulary lookup
     * @param replacementVocab   vocabulary index to sanitized token overrides, consulted first
     * @return the replacement token if one is mapped for {@code id}, otherwise the result of
     *         {@code tokenizationResult.getFromVocab(id)}
     */
    static String tokenForId(int id, TokenizationResult tokenizationResult, Map<Integer, String> replacementVocab) {
        String replacement = replacementVocab.get(id);
        if (replacement != null) {
            return replacement;
        }
        return tokenizationResult.getFromVocab(id);
    }
}
