package org.elasticsearch.xpack.ml.action;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.admin.cluster.node.stats.TransportNodesStatsAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.TransportAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor;

/* loaded from: input_file:org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.class */
public class TransportGetTrainedModelsStatsAction extends TransportAction<GetTrainedModelsStatsAction.Request, GetTrainedModelsStatsAction.Response> {
    private static final Logger logger = LogManager.getLogger(TransportGetTrainedModelsStatsAction.class);
    private final Client client;
    private final ClusterService clusterService;
    private final TrainedModelProvider trainedModelProvider;
    private final Executor executor;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction$IngestStatsAccumulator.class */
    public static class IngestStatsAccumulator {
        CounterMetric ingestCount = new CounterMetric();
        CounterMetric ingestTimeInMillis = new CounterMetric();
        CounterMetric ingestCurrent = new CounterMetric();
        CounterMetric ingestFailedCount = new CounterMetric();
        String type;

        IngestStatsAccumulator() {
        }

        IngestStatsAccumulator(String str) {
            this.type = str;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public void inc(IngestStats.Stats stats) {
            this.ingestCount.inc(stats.ingestCount());
            this.ingestTimeInMillis.inc(stats.ingestTimeInMillis());
            this.ingestCurrent.inc(stats.ingestCurrent());
            this.ingestFailedCount.inc(stats.ingestFailedCount());
        }

        IngestStats.Stats build() {
            return new IngestStats.Stats(this.ingestCount.count(), this.ingestTimeInMillis.count(), this.ingestCurrent.count(), this.ingestFailedCount.count());
        }
    }

    @Inject
    public TransportGetTrainedModelsStatsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, ThreadPool threadPool, TrainedModelProvider trainedModelProvider, Client client) {
        super("cluster:monitor/xpack/ml/inference/stats/get", actionFilters, transportService.getTaskManager());
        this.client = client;
        this.clusterService = clusterService;
        this.trainedModelProvider = trainedModelProvider;
        this.executor = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
    }

    protected void doExecute(Task task, GetTrainedModelsStatsAction.Request request, ActionListener<GetTrainedModelsStatsAction.Response> actionListener) {
        this.executor.execute(ActionRunnable.wrap(actionListener, actionListener2 -> {
            doExecuteForked(task, request, actionListener2);
        }));
    }

    protected void doExecuteForked(Task task, GetTrainedModelsStatsAction.Request request, ActionListener<GetTrainedModelsStatsAction.Response> actionListener) {
        TaskId taskId = new TaskId(this.clusterService.localNode().getId(), task.getId());
        ModelAliasMetadata fromState = ModelAliasMetadata.fromState(this.clusterService.state());
        TrainedModelAssignmentMetadata fromState2 = TrainedModelAssignmentMetadata.fromState(this.clusterService.state());
        Set<String> matchedDeploymentIds = matchedDeploymentIds(request.getResourceId(), fromState2);
        GetTrainedModelsStatsAction.Response.Builder builder = new GetTrainedModelsStatsAction.Response.Builder();
        SubscribableListener andThen = SubscribableListener.newForked(actionListener2 -> {
            String addModelsUsedInMatchingDeployments = addModelsUsedInMatchingDeployments(request.getResourceId(), fromState2);
            logger.debug("Expanded models/deployment Ids request [{}]", addModelsUsedInMatchingDeployments);
            this.trainedModelProvider.expandIds(addModelsUsedInMatchingDeployments, request.isAllowNoResources(), request.getPageParams(), Collections.emptySet(), fromState, taskId, matchedDeploymentIds, actionListener2);
        }).andThenAccept(tuple -> {
            builder.setExpandedModelIdsWithAliases((Map) tuple.v2()).setTotalModelCount(((Long) tuple.v1()).longValue());
        }).andThen((actionListener3, r9) -> {
            ClientHelper.executeAsyncWithOrigin(this.client, "ml", TransportNodesStatsAction.TYPE, nodeStatsRequest(this.clusterService.state(), taskId), actionListener3);
        }).andThen(this.executor, (ThreadContext) null, (actionListener4, nodesStatsResponse) -> {
            Set set = (Set) builder.getExpandedModelIdsWithAliases().entrySet().stream().flatMap(entry -> {
                return Stream.concat(((Set) entry.getValue()).stream(), Stream.of((String) entry.getKey()));
            }).collect(Collectors.toSet());
            set.addAll(matchedDeploymentIds);
            builder.setIngestStatsByModelId(inferenceIngestStatsByModelId(nodesStatsResponse, fromState, InferenceProcessorInfoExtractor.pipelineIdsByResource(this.clusterService.state(), set)));
            this.trainedModelProvider.getInferenceStats((String[]) builder.getExpandedModelIdsWithAliases().keySet().toArray(new String[0]), taskId, actionListener4);
        }).andThenAccept(list -> {
            builder.setInferenceStatsByModelId((Map) list.stream().collect(Collectors.toMap((v0) -> {
                return v0.getModelId();
            }, Function.identity())));
        }).andThen(this.executor, (ThreadContext) null, (actionListener5, r11) -> {
            getDeploymentStats(this.client, request.getResourceId(), taskId, fromState2, actionListener5);
        }).andThenApply(response -> {
            builder.setDeploymentStatsByDeploymentId((Map) response.getStats().results().stream().collect(Collectors.toMap((v0) -> {
                return v0.getDeploymentId();
            }, Function.identity())));
            return Integer.valueOf(response.getStats().results().stream().mapToInt((v0) -> {
                return v0.getNumberOfAllocations();
            }).sum());
        }).andThen(this.executor, (ThreadContext) null, (actionListener6, num) -> {
            modelSizeStats(builder.getExpandedModelIdsWithAliases(), request.isAllowNoResources(), taskId, actionListener6, num.intValue());
        });
        Objects.requireNonNull(builder);
        andThen.andThenAccept(builder::setModelSizeStatsByModelId).andThenApply(r6 -> {
            return builder.build(modelToDeployments(builder.getExpandedModelIdsWithAliases().keySet(), fromState2));
        }).addListener(actionListener, this.executor, (ThreadContext) null);
    }

    static String addModelsUsedInMatchingDeployments(String str, TrainedModelAssignmentMetadata trainedModelAssignmentMetadata) {
        if (Strings.isAllOrWildcard(str)) {
            return str;
        }
        HashSet hashSet = new HashSet(Arrays.asList(ExpandedIdsMatcher.tokenizeExpression(str)));
        hashSet.addAll(modelsUsedByMatchingDeploymentId(str, trainedModelAssignmentMetadata));
        return String.join(",", hashSet);
    }

    static Map<String, Set<String>> modelToDeployments(Set<String> set, TrainedModelAssignmentMetadata trainedModelAssignmentMetadata) {
        HashMap hashMap = new HashMap();
        for (TrainedModelAssignment trainedModelAssignment : trainedModelAssignmentMetadata.allAssignments().values()) {
            if (set.contains(trainedModelAssignment.getModelId())) {
                ((Set) hashMap.computeIfAbsent(trainedModelAssignment.getModelId(), str -> {
                    return new HashSet();
                })).add(trainedModelAssignment.getDeploymentId());
            }
        }
        return hashMap;
    }

    static Set<String> matchedDeploymentIds(String str, TrainedModelAssignmentMetadata trainedModelAssignmentMetadata) {
        HashSet hashSet = new HashSet();
        ExpandedIdsMatcher.SimpleIdsMatcher simpleIdsMatcher = new ExpandedIdsMatcher.SimpleIdsMatcher(str);
        for (TrainedModelAssignment trainedModelAssignment : trainedModelAssignmentMetadata.allAssignments().values()) {
            if (simpleIdsMatcher.idMatches(trainedModelAssignment.getDeploymentId())) {
                hashSet.add(trainedModelAssignment.getDeploymentId());
            }
        }
        return hashSet;
    }

    static Set<String> modelsUsedByMatchingDeploymentId(String str, TrainedModelAssignmentMetadata trainedModelAssignmentMetadata) {
        HashSet hashSet = new HashSet();
        ExpandedIdsMatcher.SimpleIdsMatcher simpleIdsMatcher = new ExpandedIdsMatcher.SimpleIdsMatcher(str);
        for (TrainedModelAssignment trainedModelAssignment : trainedModelAssignmentMetadata.allAssignments().values()) {
            if (simpleIdsMatcher.idMatches(trainedModelAssignment.getDeploymentId())) {
                hashSet.add(trainedModelAssignment.getModelId());
            }
        }
        return hashSet;
    }

    static void getDeploymentStats(Client client, String str, TaskId taskId, TrainedModelAssignmentMetadata trainedModelAssignmentMetadata, ActionListener<GetDeploymentStatsAction.Response> actionListener) {
        ExpandedIdsMatcher.SimpleIdsMatcher simpleIdsMatcher = new ExpandedIdsMatcher.SimpleIdsMatcher(str);
        HashSet hashSet = new HashSet();
        for (TrainedModelAssignment trainedModelAssignment : trainedModelAssignmentMetadata.allAssignments().values()) {
            if (simpleIdsMatcher.idMatches(trainedModelAssignment.getDeploymentId())) {
                hashSet.add(trainedModelAssignment.getDeploymentId());
            } else if (simpleIdsMatcher.idMatches(trainedModelAssignment.getModelId())) {
                hashSet.add(trainedModelAssignment.getDeploymentId());
            }
        }
        String join = String.join(",", hashSet);
        logger.debug("Fetching stats for deployments [{}]", join);
        GetDeploymentStatsAction.Request request = new GetDeploymentStatsAction.Request(join);
        request.setParentTask(taskId);
        ClientHelper.executeAsyncWithOrigin(client, "ml", GetDeploymentStatsAction.INSTANCE, request, actionListener);
    }

    private void modelSizeStats(Map<String, Set<String>> map, boolean z, TaskId taskId, ActionListener<Map<String, TrainedModelSizeStats>> actionListener, int i) {
        CheckedConsumer checkedConsumer = list -> {
            List<String> list = list.stream().filter(trainedModelConfig -> {
                return trainedModelConfig.getModelType() == TrainedModelType.PYTORCH;
            }).map((v0) -> {
                return v0.getModelId();
            }).toList();
            CheckedConsumer checkedConsumer2 = map2 -> {
                long j;
                HashMap hashMap = new HashMap();
                Iterator it = list.iterator();
                while (it.hasNext()) {
                    TrainedModelConfig trainedModelConfig2 = (TrainedModelConfig) it.next();
                    if (trainedModelConfig2.getModelType() == TrainedModelType.PYTORCH) {
                        long longValue = ((Long) map2.getOrDefault(trainedModelConfig2.getModelId(), 0L)).longValue();
                        boolean useNewMemoryFields = TrainedModelAssignment.useNewMemoryFields(TransportVersionUtils.getMinTransportVersion(this.clusterService.state()));
                        if (longValue > 0) {
                            j = StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(trainedModelConfig2.getModelId(), longValue, useNewMemoryFields ? trainedModelConfig2.getPerDeploymentMemoryBytes() : 0L, useNewMemoryFields ? trainedModelConfig2.getPerAllocationMemoryBytes() : 0L, i);
                        } else {
                            j = 0;
                        }
                        hashMap.put(trainedModelConfig2.getModelId(), new TrainedModelSizeStats(longValue, j));
                    } else {
                        hashMap.put(trainedModelConfig2.getModelId(), new TrainedModelSizeStats(trainedModelConfig2.getModelSize(), 0L));
                    }
                }
                actionListener.onResponse(hashMap);
            };
            Objects.requireNonNull(actionListener);
            definitionLengths(list, taskId, ActionListener.wrap(checkedConsumer2, actionListener::onFailure));
        };
        Objects.requireNonNull(actionListener);
        this.trainedModelProvider.getTrainedModels(map, GetTrainedModelsAction.Includes.empty(), z, taskId, ActionListener.wrap(checkedConsumer, actionListener::onFailure));
    }

    private void definitionLengths(List<String> list, TaskId taskId, ActionListener<Map<String, Long>> actionListener) {
        SearchRequest request = this.client.prepareSearch(new String[]{".ml-inference-*"}).setQuery(QueryBuilders.constantScoreQuery(QueryBuilders.boolQuery().filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME)).filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), list)).filter(QueryBuilders.termQuery(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName(), 0)))).setFetchSource(false).addDocValueField(TrainedModelConfig.MODEL_ID.getPreferredName()).addDocValueField(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName()).addSort("_index", SortOrder.DESC).request();
        request.setParentTask(taskId);
        Client client = this.client;
        ActionType actionType = TransportSearchAction.TYPE;
        CheckedConsumer checkedConsumer = searchResponse -> {
            HashMap hashMap = new HashMap();
            for (SearchHit searchHit : searchResponse.getHits().getHits()) {
                DocumentField field = searchHit.field(TrainedModelConfig.MODEL_ID.getPreferredName());
                if (field != null) {
                    Object value = field.getValue();
                    if (value instanceof String) {
                        String str = (String) value;
                        DocumentField field2 = searchHit.field(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName());
                        if (field2 != null) {
                            Object value2 = field2.getValue();
                            if (value2 instanceof Long) {
                                hashMap.put(str, (Long) value2);
                            }
                        }
                    }
                }
            }
            actionListener.onResponse(hashMap);
        };
        Objects.requireNonNull(actionListener);
        ClientHelper.executeAsyncWithOrigin(client, "ml", actionType, request, ActionListener.wrap(checkedConsumer, actionListener::onFailure));
    }

    static Map<String, IngestStats> inferenceIngestStatsByModelId(NodesStatsResponse nodesStatsResponse, ModelAliasMetadata modelAliasMetadata, Map<String, Set<String>> map) {
        HashMap hashMap = new HashMap();
        ((Map) map.entrySet().stream().collect(Collectors.toMap(entry -> {
            String modelId = modelAliasMetadata.getModelId((String) entry.getKey());
            return modelId == null ? (String) entry.getKey() : modelId;
        }, (v0) -> {
            return v0.getValue();
        }, Sets::union))).forEach((str, set) -> {
            hashMap.put(str, mergeStats((List) nodesStatsResponse.getNodes().stream().map(nodeStats -> {
                return ingestStatsForPipelineIds(nodeStats, set);
            }).collect(Collectors.toList())));
        });
        return hashMap;
    }

    static NodesStatsRequest nodeStatsRequest(ClusterState clusterState, TaskId taskId) {
        NodesStatsRequest addMetric = new NodesStatsRequest((String[]) clusterState.nodes().getIngestNodes().keySet().toArray(i -> {
            return new String[i];
        })).clear().addMetric(NodesStatsRequestParameters.Metric.INGEST.metricName());
        addMetric.setIncludeShardsStats(false);
        addMetric.setParentTask(taskId);
        return addMetric;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set<String> set) {
        IngestStats ingestStats = nodeStats.getIngestStats();
        HashMap hashMap = new HashMap(ingestStats.processorStats());
        hashMap.keySet().retainAll(set);
        List list = (List) ingestStats.pipelineStats().stream().filter(pipelineStat -> {
            return set.contains(pipelineStat.pipelineId());
        }).collect(Collectors.toList());
        CounterMetric counterMetric = new CounterMetric();
        CounterMetric counterMetric2 = new CounterMetric();
        CounterMetric counterMetric3 = new CounterMetric();
        CounterMetric counterMetric4 = new CounterMetric();
        list.forEach(pipelineStat2 -> {
            IngestStats.Stats stats = pipelineStat2.stats();
            counterMetric.inc(stats.ingestCount());
            counterMetric2.inc(stats.ingestTimeInMillis());
            counterMetric3.inc(stats.ingestCurrent());
            counterMetric4.inc(stats.ingestFailedCount());
        });
        return new IngestStats(new IngestStats.Stats(counterMetric.count(), counterMetric2.count(), counterMetric3.count(), counterMetric4.count()), list, hashMap);
    }

    private static IngestStats mergeStats(List<IngestStats> list) {
        LinkedHashMap newLinkedHashMapWithExpectedSize = Maps.newLinkedHashMapWithExpectedSize(list.size());
        LinkedHashMap newLinkedHashMapWithExpectedSize2 = Maps.newLinkedHashMapWithExpectedSize(list.size());
        IngestStatsAccumulator ingestStatsAccumulator = new IngestStatsAccumulator();
        list.forEach(ingestStats -> {
            ingestStats.pipelineStats().forEach(pipelineStat -> {
                ((IngestStatsAccumulator) newLinkedHashMapWithExpectedSize.computeIfAbsent(pipelineStat.pipelineId(), str -> {
                    return new IngestStatsAccumulator();
                })).inc(pipelineStat.stats());
            });
            ingestStats.processorStats().forEach((str, list2) -> {
                Map map = (Map) newLinkedHashMapWithExpectedSize2.computeIfAbsent(str, str -> {
                    return new LinkedHashMap();
                });
                list2.forEach(processorStat -> {
                    ((IngestStatsAccumulator) map.computeIfAbsent(processorStat.name(), str2 -> {
                        return new IngestStatsAccumulator(processorStat.type());
                    })).inc(processorStat.stats());
                });
            });
            ingestStatsAccumulator.inc(ingestStats.totalStats());
        });
        ArrayList arrayList = new ArrayList(newLinkedHashMapWithExpectedSize.size());
        newLinkedHashMapWithExpectedSize.forEach((str, ingestStatsAccumulator2) -> {
            arrayList.add(new IngestStats.PipelineStat(str, ingestStatsAccumulator2.build()));
        });
        LinkedHashMap newLinkedHashMapWithExpectedSize3 = Maps.newLinkedHashMapWithExpectedSize(newLinkedHashMapWithExpectedSize2.size());
        newLinkedHashMapWithExpectedSize2.forEach((str2, map) -> {
            ArrayList arrayList2 = new ArrayList(map.size());
            map.forEach((str2, ingestStatsAccumulator3) -> {
                arrayList2.add(new IngestStats.ProcessorStat(str2, ingestStatsAccumulator3.type, ingestStatsAccumulator3.build()));
            });
            newLinkedHashMapWithExpectedSize3.put(str2, arrayList2);
        });
        return new IngestStats(ingestStatsAccumulator.build(), arrayList, newLinkedHashMapWithExpectedSize3);
    }

    protected /* bridge */ /* synthetic */ void doExecute(Task task, ActionRequest actionRequest, ActionListener actionListener) {
        doExecute(task, (GetTrainedModelsStatsAction.Request) actionRequest, (ActionListener<GetTrainedModelsStatsAction.Response>) actionListener);
    }
}
