package org.elasticsearch.xpack.ml.inference.nlp;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.class */
/**
 * Post-processes raw PyTorch output for text classification into {@link NlpClassificationInferenceResults}.
 *
 * <p>The class labels and default number of top classes come from the {@link TextClassificationConfig}
 * supplied at construction. A {@link TextClassificationConfig} passed to
 * {@link #getResultProcessor(NlpConfig)} takes precedence over those defaults.
 */
public class TextClassificationProcessor extends NlpTask.Processor {

    /** Results field used when no other field name is configured. */
    private static final String DEFAULT_RESULTS_FIELD = "predicted_value";

    private final NlpTask.RequestBuilder requestBuilder;
    private final String[] classLabels;
    private final int numTopClasses;

    /**
     * @param nlpTokenizer             tokenizer whose request builder is used for every request
     * @param textClassificationConfig source of the class labels and the default number of top classes;
     *                                 a negative {@code num_top_classes} means "report all labels"
     */
    public TextClassificationProcessor(NlpTokenizer nlpTokenizer, TextClassificationConfig textClassificationConfig) {
        super(nlpTokenizer);
        this.requestBuilder = nlpTokenizer.requestBuilder();
        this.classLabels = textClassificationConfig.getClassificationLabels().toArray(String[]::new);
        this.numTopClasses = resolveNumTopClasses(textClassificationConfig.getNumTopClasses(), this.classLabels.length);
    }

    /** Performs no validation: text classification imposes no extra constraints on the raw inputs. */
    @Override
    public void validateInputs(List<String> inputs) {
        // intentionally empty
    }

    /** Returns the tokenizer's request builder; {@code nlpConfig} is not consulted. */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) {
        return this.requestBuilder;
    }

    /**
     * Builds the result processor for a request.
     *
     * <p>If {@code nlpConfig} is a {@link TextClassificationConfig}, its labels, top-class count and results
     * field are used. They are read each time the returned processor runs. Otherwise the values captured at
     * construction are used, with results written to {@code "predicted_value"}.
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) {
        if (nlpConfig instanceof TextClassificationConfig) {
            TextClassificationConfig config = (TextClassificationConfig) nlpConfig;
            return (tokenization, pyTorchResult, chunkResult) -> processResult(
                tokenization,
                pyTorchResult,
                resolveNumTopClasses(config.getNumTopClasses(), config.getClassificationLabels().size()),
                config.getClassificationLabels(),
                config.getResultsField(),
                chunkResult
            );
        }
        // Read-only fixed-size view over the constructor-time labels; processResult never mutates it.
        List<String> labels = Arrays.asList(this.classLabels);
        return (tokenization, pyTorchResult, chunkResult) -> processResult(
            tokenization,
            pyTorchResult,
            this.numTopClasses,
            labels,
            DEFAULT_RESULTS_FIELD,
            chunkResult
        );
    }

    /**
     * Converts the model's logits into a classification result.
     *
     * <p>Every row of the first inference-result entry is converted to probabilities with softmax. The rows
     * are then averaged, and labels are ranked by descending average probability. When scores are equal,
     * the lower label index comes first.
     *
     * @param tokenizationResult     tokenization of the request; it must hold at most one sequence
     * @param pyTorchInferenceResult raw model output; each row of entry {@code [0]} must have one value per label
     * @param numTopClasses          maximum number of {@link TopClassEntry} items to report (must be non-negative)
     * @param labels                 class labels, indexed like the model's output columns
     * @param resultsField           results field name; {@code null} falls back to {@code "predicted_value"}
     * @param chunkResult            whether chunked output was requested (not supported for this task)
     * @return the top label, its averaged probability, the top classes and the truncation flag
     * @throws ElasticsearchStatusException if the output is empty, a row has the wrong width, or the
     *                                      tokenization holds more than one sequence
     */
    public static InferenceResults processResult(
        TokenizationResult tokenizationResult,
        PyTorchInferenceResult pyTorchInferenceResult,
        int numTopClasses,
        List<String> labels,
        String resultsField,
        boolean chunkResult
    ) {
        if (chunkResult) {
            throw chunkingNotSupportedException(TaskType.TEXT_CLASSIFICATION);
        }
        double[][][] rawResult = pyTorchInferenceResult.getInferenceResult();
        // An empty first entry would otherwise lead to divMut(..., 0) and NaN "probabilities".
        if (rawResult.length < 1 || rawResult[0].length < 1) {
            throw new ElasticsearchStatusException("Text classification result has no data", RestStatus.INTERNAL_SERVER_ERROR);
        }
        // Only the first entry is used; multi-sequence (batch) input is rejected below.
        double[][] rows = rawResult[0];
        for (double[] row : rows) {
            if (row.length != labels.size()) {
                throw new ElasticsearchStatusException(
                    "Expected exactly [{}] values in text classification result; got [{}]",
                    RestStatus.INTERNAL_SERVER_ERROR,
                    labels.size(),
                    row.length
                );
            }
        }
        if (tokenizationResult.getTokensBySequenceId().size() > 1) {
            throw new ElasticsearchStatusException("Unexpected batch input for text classification", RestStatus.INTERNAL_SERVER_ERROR);
        }

        double[] scores = averageSoftMaxProbabilities(rows, labels.size());
        int[] ranked = rankDescending(scores);
        int best = ranked[0];

        List<TopClassEntry> topClasses = Arrays.stream(ranked)
            .limit(numTopClasses)
            .mapToObj(idx -> new TopClassEntry(labels.get(idx), scores[idx]))
            .collect(Collectors.toList());

        return new NlpClassificationInferenceResults(
            labels.get(best),
            topClasses,
            Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),
            scores[best],
            tokenizationResult.anyTruncated()
        );
    }

    /** Returns {@code requested}, or {@code labelCount} when {@code requested} is negative ("all labels"). */
    private static int resolveNumTopClasses(int requested, int labelCount) {
        return requested < 0 ? labelCount : requested;
    }

    /** Softmaxes each row and returns the element-wise mean; {@code rows} must be non-empty. */
    private static double[] averageSoftMaxProbabilities(double[][] rows, int labelCount) {
        double[] sums = new double[labelCount];
        for (double[] row : rows) {
            InferenceHelpers.sumDoubleArrays(sums, NlpHelpers.convertToProbabilitiesBySoftMax(row));
        }
        InferenceHelpers.divMut(sums, rows.length);
        return sums;
    }

    /** Returns the indices of {@code scores} ordered by descending score; the stable sort keeps ties in index order. */
    private static int[] rankDescending(double[] scores) {
        return IntStream.range(0, scores.length)
            .boxed()
            .sorted(Comparator.comparingDouble((Integer idx) -> scores[idx]).reversed())
            .mapToInt(Integer::intValue)
            .toArray();
    }
}
