package org.elasticsearch.xpack.ml.inference.ltr;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.script.GeneralScriptException;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.script.TemplateScript;
import org.elasticsearch.script.mustache.MustacheInvalidParameterException;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearningToRankFeatureExtractorBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/ltr/LearningToRankService.class */
public class LearningToRankService {
    private static final Map<String, String> SCRIPT_OPTIONS = Map.ofEntries(Map.entry("detect_missing_params", Boolean.TRUE.toString()));
    private final ModelLoadingService modelLoadingService;
    private final TrainedModelProvider trainedModelProvider;
    private final ScriptService scriptService;
    private final XContentParserConfiguration parserConfiguration;

    public LearningToRankService(ModelLoadingService modelLoadingService, TrainedModelProvider trainedModelProvider, ScriptService scriptService, NamedXContentRegistry namedXContentRegistry) {
        this(modelLoadingService, trainedModelProvider, scriptService, XContentParserConfiguration.EMPTY.withRegistry(namedXContentRegistry));
    }

    LearningToRankService(ModelLoadingService modelLoadingService, TrainedModelProvider trainedModelProvider, ScriptService scriptService, XContentParserConfiguration xContentParserConfiguration) {
        this.modelLoadingService = modelLoadingService;
        this.scriptService = scriptService;
        this.trainedModelProvider = trainedModelProvider;
        this.parserConfiguration = xContentParserConfiguration;
    }

    public void loadLocalModel(String str, ActionListener<LocalModel> actionListener) {
        this.modelLoadingService.getModelForLearningToRank(str, actionListener);
    }

    public void loadLearningToRankConfig(String str, Map<String, Object> map, ActionListener<LearningToRankConfig> actionListener) {
        TrainedModelProvider trainedModelProvider = this.trainedModelProvider;
        GetTrainedModelsAction.Includes all = GetTrainedModelsAction.Includes.all();
        CheckedConsumer checkedConsumer = trainedModelConfig -> {
            InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig();
            if (inferenceConfig instanceof LearningToRankConfig) {
                actionListener.onResponse(applyParams((LearningToRankConfig) inferenceConfig, (Map<String, Object>) map));
            } else {
                actionListener.onFailure(ExceptionsHelper.badRequestException(Messages.getMessage("Inference config of type [{0}] is invalid, must be of type [{1}]", new Object[]{Optional.ofNullable(trainedModelConfig.getInferenceConfig()).map((v0) -> {
                    return v0.getName();
                }).orElse("null"), LearningToRankConfig.NAME.getPreferredName()}), new Object[0]));
            }
        };
        Objects.requireNonNull(actionListener);
        trainedModelProvider.getTrainedModel(str, all, null, ActionListener.wrap(checkedConsumer, actionListener::onFailure));
    }

    private LearningToRankConfig applyParams(LearningToRankConfig learningToRankConfig, Map<String, Object> map) throws Exception {
        if (!this.scriptService.isLangSupported("mustache")) {
            return learningToRankConfig;
        }
        ArrayList arrayList = new ArrayList();
        HashMap hashMap = new HashMap((Map) Objects.requireNonNullElse(map, Map.of()));
        XContentHelper.mergeDefaults(hashMap, learningToRankConfig.getParamsDefaults());
        Iterator it = learningToRankConfig.getFeatureExtractorBuilders().iterator();
        while (it.hasNext()) {
            arrayList.add(applyParams((LearningToRankFeatureExtractorBuilder) it.next(), hashMap));
        }
        return LearningToRankConfig.builder(learningToRankConfig).setLearningToRankFeatureExtractorBuilders(arrayList).build();
    }

    private LearningToRankFeatureExtractorBuilder applyParams(LearningToRankFeatureExtractorBuilder learningToRankFeatureExtractorBuilder, Map<String, Object> map) throws Exception {
        if (learningToRankFeatureExtractorBuilder instanceof QueryExtractorBuilder) {
            learningToRankFeatureExtractorBuilder = applyParams((QueryExtractorBuilder) learningToRankFeatureExtractorBuilder, map);
        }
        learningToRankFeatureExtractorBuilder.validate();
        return learningToRankFeatureExtractorBuilder;
    }

    private QueryExtractorBuilder applyParams(QueryExtractorBuilder queryExtractorBuilder, Map<String, Object> map) throws IOException {
        String templateSource = templateSource(queryExtractorBuilder.query());
        if (!templateSource.contains("{{")) {
            return queryExtractorBuilder;
        }
        try {
            XContentParser createParser = XContentType.JSON.xContent().createParser(this.parserConfiguration, ((TemplateScript.Factory) this.scriptService.compile(new Script(ScriptType.INLINE, "mustache", templateSource, SCRIPT_OPTIONS, Collections.emptyMap()), TemplateScript.CONTEXT)).newInstance(map).execute());
            try {
                QueryExtractorBuilder queryExtractorBuilder2 = new QueryExtractorBuilder(queryExtractorBuilder.featureName(), QueryProvider.fromXContent(createParser, false, "Inference config query is not parsable"), queryExtractorBuilder.defaultScore());
                if (createParser != null) {
                    createParser.close();
                }
                return queryExtractorBuilder2;
            } finally {
            }
        } catch (GeneralScriptException e) {
            if (e.getRootCause().getClass().getName().equals(MustacheInvalidParameterException.class.getName())) {
                return new QueryExtractorBuilder(queryExtractorBuilder.featureName(), defaultQuery(queryExtractorBuilder.defaultScore()), queryExtractorBuilder.defaultScore());
            }
            throw e;
        }
    }

    private String templateSource(QueryProvider queryProvider) throws IOException {
        XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
        try {
            String utf8ToString = BytesReference.bytes(queryProvider.toXContent(builder, ToXContent.EMPTY_PARAMS)).utf8ToString();
            if (builder != null) {
                builder.close();
            }
            return utf8ToString;
        } catch (Throwable th) {
            if (builder != null) {
                try {
                    builder.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    private QueryProvider defaultQuery(float f) throws IOException {
        return QueryProvider.fromParsedQuery(f == 0.0f ? new MatchNoneQueryBuilder() : new MatchAllQueryBuilder().boost(f));
    }
}
