package org.elasticsearch.xpack.ml.inference.assignment;

import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.ClusterStateUpdateTask;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.gateway.GatewayService;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.core.ml.utils.TransportVersionUtils;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AllocationReducer;
import org.elasticsearch.xpack.ml.job.NodeLoad;
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
import org.elasticsearch.xpack.ml.notifications.SystemAuditor;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.class */
public class TrainedModelAssignmentClusterService implements ClusterStateListener {
    // Logger shared by this class and the anonymous tasks/listeners it creates.
    private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentClusterService.class);
    // forceUpdate stores metadata under "trained_model_assignment" when the cluster's minimum transport
    // version is on or after this version, and under "trained_model_allocation" otherwise.
    private static final TransportVersion RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION = TransportVersions.V_8_3_0;
    // While the cluster's minimum transport version is before this version, clusterChanged only prunes
    // routes and createNewModelAssignment rejects new assignments.
    public static final TransportVersion DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION = TransportVersions.V_8_4_0;
    // Collaborators supplied through the constructor; all but client are null-checked there.
    private final ClusterService clusterService;
    private final ThreadPool threadPool;
    private final NodeLoadDetector nodeLoadDetector;
    private final SystemAuditor systemAuditor;
    private final NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper;
    private final Client client;
    // Setting-backed values: read from Settings in the constructor and overwritten by the settings update
    // consumers the constructor registers when DiscoveryNode.isMasterNode(settings) is true.
    private volatile int maxMemoryPercentage;
    private volatile boolean useAuto;
    private volatile int maxOpenJobs;
    protected volatile int maxLazyMLNodes;
    // Stored in bytes (ByteSizeValue.getBytes()).
    protected volatile long maxMLNodeSize;
    protected volatile int allocatedProcessorsScale;

    public TrainedModelAssignmentClusterService(Settings settings, ClusterService clusterService, ThreadPool threadPool, NodeLoadDetector nodeLoadDetector, SystemAuditor systemAuditor, NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper, Client client) {
        this.clusterService = (ClusterService) Objects.requireNonNull(clusterService);
        this.threadPool = (ThreadPool) Objects.requireNonNull(threadPool);
        this.nodeLoadDetector = (NodeLoadDetector) Objects.requireNonNull(nodeLoadDetector);
        this.systemAuditor = (SystemAuditor) Objects.requireNonNull(systemAuditor);
        this.nodeAvailabilityZoneMapper = (NodeAvailabilityZoneMapper) Objects.requireNonNull(nodeAvailabilityZoneMapper);
        this.maxMemoryPercentage = ((Integer) MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings)).intValue();
        this.useAuto = ((Boolean) MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings)).booleanValue();
        this.maxOpenJobs = ((Integer) MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(settings)).intValue();
        this.maxLazyMLNodes = ((Integer) MachineLearning.MAX_LAZY_ML_NODES.get(settings)).intValue();
        this.maxMLNodeSize = ((ByteSizeValue) MachineLearning.MAX_ML_NODE_SIZE.get(settings)).getBytes();
        this.allocatedProcessorsScale = ((Integer) MachineLearning.ALLOCATED_PROCESSORS_SCALE.get(settings)).intValue();
        this.client = client;
        if (DiscoveryNode.isMasterNode(settings)) {
            clusterService.addListener(this);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, (v1) -> {
                setMaxMemoryPercentage(v1);
            });
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT, (v1) -> {
                setUseAuto(v1);
            });
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_OPEN_JOBS_PER_NODE, (v1) -> {
                setMaxOpenJobs(v1);
            });
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_LAZY_ML_NODES, (v1) -> {
                setMaxLazyMLNodes(v1);
            });
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_ML_NODE_SIZE, this::setMaxMLNodeSize);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.ALLOCATED_PROCESSORS_SCALE, (v1) -> {
                setAllocatedProcessorsScale(v1);
            });
        }
    }

    /**
     * Stores a new value for {@code maxMemoryPercentage}; the constructor registers this as the update
     * consumer for {@code MachineLearning.MAX_MACHINE_MEMORY_PERCENT}.
     */
    private void setMaxMemoryPercentage(int i) {
        this.maxMemoryPercentage = i;
    }

    /**
     * Stores a new value for {@code useAuto}; the constructor registers this as the update consumer for
     * {@code MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT}.
     */
    private void setUseAuto(boolean z) {
        this.useAuto = z;
    }

    /**
     * Stores a new value for {@code maxOpenJobs}; the constructor registers this as the update consumer
     * for {@code MachineLearning.MAX_OPEN_JOBS_PER_NODE}.
     */
    private void setMaxOpenJobs(int i) {
        this.maxOpenJobs = i;
    }

    /**
     * Stores a new value for {@code maxLazyMLNodes}; the constructor registers this as the update
     * consumer for {@code MachineLearning.MAX_LAZY_ML_NODES}.
     */
    private void setMaxLazyMLNodes(int i) {
        this.maxLazyMLNodes = i;
    }

    /**
     * Stores a new value for {@code maxMLNodeSize}, converted to bytes; the constructor registers this as
     * the update consumer for {@code MachineLearning.MAX_ML_NODE_SIZE}.
     */
    private void setMaxMLNodeSize(ByteSizeValue byteSizeValue) {
        this.maxMLNodeSize = byteSizeValue.getBytes();
    }

    /**
     * Stores a new value for {@code allocatedProcessorsScale}; the constructor registers this as the
     * update consumer for {@code MachineLearning.ALLOCATED_PROCESSORS_SCALE}.
     */
    private void setAllocatedProcessorsScale(int i) {
        this.allocatedProcessorsScale = i;
    }

    /**
     * Single point through which this class submits cluster state update tasks; delegates to
     * {@code ClusterService.submitUnbatchedStateUpdateTask}.
     *
     * @param str                    the source description passed along with the task
     * @param clusterStateUpdateTask the task to submit
     */
    @SuppressForbidden(reason = "legacy usage of unbatched task")
    private void submitUnbatchedTask(String str, ClusterStateUpdateTask clusterStateUpdateTask) {
        this.clusterService.submitUnbatchedStateUpdateTask(str, clusterStateUpdateTask);
    }

    /**
     * Reacts to a cluster state change on the elected master.
     * <p>
     * Nothing happens while the state-not-recovered block is present or when the local node is not the
     * master. If the cluster's minimum transport version predates
     * {@link #DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION}, only routes to removed or shutting down
     * nodes are cleaned up. Otherwise ML node heterogeneity is logged when nodes were added, and a
     * rebalance is started if {@code detectReasonToRebalanceModels} yields a reason.
     *
     * @param clusterChangedEvent the change notification
     */
    @Override
    public void clusterChanged(ClusterChangedEvent clusterChangedEvent) {
        if (eventStateHasGlobalBlockStateNotRecoveredBlock(clusterChangedEvent)) {
            return;
        }
        if (clusterChangedEvent.localNodeMaster() == false) {
            return;
        }
        if (eventStateMinTransportVersionIsBeforeDistributedModelAllocationTransportVersion(clusterChangedEvent)) {
            removeRoutingToRemovedOrShuttingDownNodes(clusterChangedEvent);
            return;
        }

        if (clusterChangedEvent.nodesAdded()) {
            logMlNodeHeterogeneity();
        }

        Optional<String> rebalanceReason = detectReasonToRebalanceModels(clusterChangedEvent);
        if (rebalanceReason.isEmpty()) {
            return;
        }

        // Outcome is only logged; nobody is waiting on this rebalance.
        ActionListener<TrainedModelAssignmentMetadata> rebalanceListener = ActionListener.wrap(
            rebalanced -> logger.debug(
                () -> Strings.format(
                    "rebalanced model assignments [%s]",
                    org.elasticsearch.common.Strings.toString(rebalanced, false, true)
                )
            ),
            failure -> logger.warn("failed to rebalance models", failure)
        );
        rebalanceAssignments(clusterChangedEvent.state(), Optional.empty(), rebalanceReason.get(), rebalanceListener);
    }

    /**
     * @return {@code true} when the minimum transport version of the event's cluster state is before
     *         {@link #DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION}
     */
    boolean eventStateMinTransportVersionIsBeforeDistributedModelAllocationTransportVersion(ClusterChangedEvent clusterChangedEvent) {
        return clusterChangedEvent.state().getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION);
    }

    /**
     * @return {@code true} when the event's cluster state carries the global
     *         {@code GatewayService.STATE_NOT_RECOVERED_BLOCK}
     */
    boolean eventStateHasGlobalBlockStateNotRecoveredBlock(ClusterChangedEvent clusterChangedEvent) {
        return clusterChangedEvent.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK);
    }

    /**
     * Requests the set of ML node platform architectures and hands the result to the listener from
     * {@link #getArchitecturesSetActionListener()}, which logs the outcome. The utility thread pool
     * executor is passed along with the request.
     */
    void logMlNodeHeterogeneity() {
        ActionListener<Set<String>> architecturesListener = getArchitecturesSetActionListener();
        ExecutorService utilityExecutor = this.threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
        MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(architecturesListener, this.client, utilityExecutor);
    }

    /**
     * Builds a listener that reports on the platform architectures found among ML nodes.
     *
     * @return a listener that logs a warning when more than one architecture is reported and logs an
     *         error when the lookup fails
     */
    static ActionListener<Set<String>> getArchitecturesSetActionListener() {
        return new ActionListener<>() {
            @Override
            public void onResponse(Set<String> set) {
                if (set.size() <= 1) {
                    // Zero or one architecture: nothing to report.
                    return;
                }
                logger.warn(
                    Strings.format(
                        "Heterogeneous platform architectures were detected among ML nodes. This will prevent the deployment of some trained models. Distinct platform architectures detected: %s",
                        set
                    )
                );
            }

            @Override
            public void onFailure(Exception exc) {
                logger.error("Failed to detect heterogeneity among ML nodes with exception: ", exc);
            }
        };
    }

    /**
     * Submits a cluster state update that drops or adjusts routes to nodes that can no longer host
     * assignments, but only when {@link #areAssignedNodesRemoved} reports that an assignment is affected.
     *
     * @param clusterChangedEvent the event describing the node changes
     */
    private void removeRoutingToRemovedOrShuttingDownNodes(ClusterChangedEvent clusterChangedEvent) {
        if (areAssignedNodesRemoved(clusterChangedEvent) == false) {
            return;
        }

        ClusterStateUpdateTask pruneRoutesTask = new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState clusterState) {
                return removeRoutingToUnassignableNodes(clusterState);
            }

            @Override
            public void onFailure(Exception exc) {
                logger.error("could not remove routing entries for removed or shutting down nodes", exc);
            }

            @Override
            public void clusterStateProcessed(ClusterState clusterState, ClusterState clusterState2) {
                // Log the metadata of the new state (second argument).
                logger.debug(
                    () -> Strings.format(
                        "updated model assignments based on node changes in the cluster; new metadata [%s]",
                        org.elasticsearch.common.Strings.toString(TrainedModelAssignmentMetadata.fromState(clusterState2), false, true)
                    )
                );
            }
        };
        submitUnbatchedTask("removing routing entries for removed or shutting down nodes", pruneRoutesTask);
    }

    /**
     * Tells whether any trained model assignment is routed to a node that was removed by this event or
     * is currently shutting down.
     * <p>
     * The check is skipped (returning {@code false}) unless the event removed nodes or changed the
     * {@code "node_shutdown"} custom metadata.
     *
     * @param clusterChangedEvent the event to inspect
     * @return {@code true} if at least one assignment's routing table references a removed or shutting
     *         down node
     */
    static boolean areAssignedNodesRemoved(ClusterChangedEvent clusterChangedEvent) {
        boolean nodeShutdownChanged = clusterChangedEvent.changedCustomMetadataSet().contains("node_shutdown");
        if (clusterChangedEvent.nodesRemoved() == false && nodeShutdownChanged == false) {
            return false;
        }

        // Start from the shutting down nodes, then add every node removed by this event.
        Set<String> removedOrShuttingDownNodeIds = new HashSet<>(nodesShuttingDown(clusterChangedEvent.state()));
        clusterChangedEvent.nodesDelta()
            .removedNodes()
            .stream()
            .map(DiscoveryNode::getId)
            .forEach(removedOrShuttingDownNodeIds::add);

        TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(clusterChangedEvent.state());
        for (TrainedModelAssignment assignment : metadata.allAssignments().values()) {
            Set<String> routedNodeIds = assignment.getNodeRoutingTable().keySet();
            if (Sets.intersection(removedOrShuttingDownNodeIds, routedNodeIds).isEmpty() == false) {
                return true;
            }
        }
        return false;
    }

    /**
     * Produces a cluster state in which every assignment has had its routes to non-assignable nodes
     * removed or switched to a shutting-down route (see {@link #removeRoutingBuilder}), and its
     * assignment state recalculated.
     *
     * @param clusterState the state to adjust
     * @return the result of {@link #update}: the same instance when the metadata is unchanged, otherwise a
     *         new state
     */
    static ClusterState removeRoutingToUnassignableNodes(ClusterState clusterState) {
        Set<String> assignableNodeIds = getAssignableNodes(clusterState).stream().map(DiscoveryNode::getId).collect(Collectors.toSet());
        TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(clusterState);
        TrainedModelAssignmentMetadata.Builder metadataBuilder = TrainedModelAssignmentMetadata.builder(clusterState);
        Set<String> shuttingDownNodeIds = clusterState.metadata().nodeShutdowns().getAllNodeIds();

        for (TrainedModelAssignment assignment : currentMetadata.allAssignments().values()) {
            // Routes whose node is not in the assignable set.
            Set<String> unassignableRoutes = Sets.difference(assignment.getNodeRoutingTable().keySet(), assignableNodeIds);
            if (unassignableRoutes.isEmpty()) {
                continue;
            }
            logger.debug(
                () -> Strings.format(
                    "[%s] removing routing entries to nodes %s because they have been removed or are shutting down",
                    assignment.getDeploymentId(),
                    unassignableRoutes
                )
            );
            TrainedModelAssignment.Builder adjusted = removeRoutingBuilder(unassignableRoutes, shuttingDownNodeIds, assignment);
            metadataBuilder.updateAssignment(assignment.getDeploymentId(), adjusted.calculateAndSetAssignmentState());
        }
        return update(clusterState, metadataBuilder);
    }

    /**
     * Builds a copy of the assignment with the given routes adjusted.
     * <p>
     * A route whose node id is not in {@code set2} is removed. A route whose node id is in {@code set2}
     * and is in state {@code STARTED} or {@code STARTING} is overwritten with a shutting-down route;
     * any other route in {@code set2} is left untouched.
     *
     * @param set                    node ids whose routes must be adjusted
     * @param set2                   node ids registered as shutting down
     * @param trainedModelAssignment the assignment to copy
     * @return a builder holding the adjusted routing table
     */
    private static TrainedModelAssignment.Builder removeRoutingBuilder(Set<String> set, Set<String> set2, TrainedModelAssignment trainedModelAssignment) {
        TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.fromAssignment(trainedModelAssignment);
        for (String nodeId : set) {
            RoutingInfo currentRoute = trainedModelAssignment.getNodeRoutingTable().get(nodeId);
            if (set2.contains(nodeId)) {
                boolean isActiveRoute = currentRoute != null
                    && currentRoute.getState().isAnyOf(new RoutingState[] { RoutingState.STARTED, RoutingState.STARTING });
                if (isActiveRoute) {
                    logger.debug(
                        () -> Strings.format(
                            "[%s] Found assignment with route to shutting down node id [%s], adding stopping route",
                            trainedModelAssignment.getDeploymentId(),
                            nodeId
                        )
                    );
                    assignmentBuilder.addOrOverwriteRoutingEntry(nodeId, TrainedModelAssignmentUtils.createShuttingDownRoute(currentRoute));
                }
            } else {
                logger.debug(
                    () -> Strings.format(
                        "[%s] Removing route for unassignable node id [%s]",
                        trainedModelAssignment.getDeploymentId(),
                        nodeId
                    )
                );
                assignmentBuilder.removeRoutingEntry(nodeId);
            }
        }
        return assignmentBuilder;
    }

    /**
     * Submits a cluster state update that applies a routing info update for one node of one deployment.
     *
     * @param request        identifies the deployment, the node and the update to apply
     * @param actionListener receives {@code AcknowledgedResponse.TRUE} once the state update has been
     *                       processed, or the failure of the update task
     */
    public void updateModelRoutingTable(final UpdateTrainedModelAssignmentRoutingInfoAction.Request request, final ActionListener<AcknowledgedResponse> actionListener) {
        logger.debug(
            () -> Strings.format(
                "[%s] updating routing table entry for node [%s], update [%s]",
                request.getDeploymentId(),
                request.getNodeId(),
                request.getUpdate()
            )
        );

        ClusterStateUpdateTask routingUpdateTask = new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState clusterState) {
                return TrainedModelAssignmentClusterService.updateModelRoutingTable(clusterState, request);
            }

            @Override
            public void onFailure(Exception exc) {
                actionListener.onFailure(exc);
            }

            @Override
            public void clusterStateProcessed(ClusterState clusterState, ClusterState clusterState2) {
                actionListener.onResponse(AcknowledgedResponse.TRUE);
            }
        };
        submitUnbatchedTask("updating model routing for node assignment", routingUpdateTask);
    }

    /**
     * Starts a rebalance that includes a new deployment and reports the resulting assignment.
     * <p>
     * The request is rejected with {@code RestStatus.CONFLICT} when the cluster's minimum transport
     * version is before {@link #DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION} or when ML metadata is in
     * reset mode. If the rebalanced metadata holds no assignment for the deployment, an empty assignment
     * built from {@code taskParams} is returned instead.
     *
     * @param taskParams     parameters of the deployment to add
     * @param actionListener receives the assignment of the deployment, or the failure
     */
    public void createNewModelAssignment(StartTrainedModelDeploymentAction.TaskParams taskParams, ActionListener<TrainedModelAssignment> actionListener) {
        // Take one snapshot so the pre-checks and the rebalance all operate on the same cluster state.
        final ClusterState state = this.clusterService.state();

        if (state.getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION)) {
            actionListener.onFailure(
                new ElasticsearchStatusException(
                    "cannot create new assignment [{}] for model [{}] while cluster upgrade is in progress",
                    RestStatus.CONFLICT,
                    taskParams.getDeploymentId(),
                    taskParams.getModelId()
                )
            );
            return;
        }
        if (MlMetadata.getMlMetadata(state).isResetMode()) {
            actionListener.onFailure(
                new ElasticsearchStatusException(
                    "cannot create new assignment [{}] for model [{}] while feature reset is in progress.",
                    RestStatus.CONFLICT,
                    taskParams.getDeploymentId(),
                    taskParams.getModelId()
                )
            );
            return;
        }

        ActionListener<TrainedModelAssignmentMetadata> rebalanceListener = ActionListener.wrap(rebalancedMetadata -> {
            TrainedModelAssignment assignment = rebalancedMetadata.getDeploymentAssignment(taskParams.getDeploymentId());
            if (assignment == null) {
                // The rebalance produced no entry for this deployment; answer with an empty assignment.
                assignment = TrainedModelAssignment.Builder.empty(taskParams).build();
            }
            actionListener.onResponse(assignment);
        }, actionListener::onFailure);
        rebalanceAssignments(state, Optional.of(taskParams), "model deployment started", rebalanceListener);
    }

    /**
     * Submits a cluster state update that applies {@code setToStopping} to the given deployment with the
     * reason {@code "client API call"}.
     *
     * @param str            the deployment id
     * @param actionListener receives {@code AcknowledgedResponse.TRUE} once the state update has been
     *                       processed, or the failure of the update task
     */
    public void setModelAssignmentToStopping(final String str, final ActionListener<AcknowledgedResponse> actionListener) {
        ClusterStateUpdateTask stoppingTask = new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState clusterState) {
                return setToStopping(clusterState, str, "client API call");
            }

            @Override
            public void onFailure(Exception exc) {
                actionListener.onFailure(exc);
            }

            @Override
            public void clusterStateProcessed(ClusterState clusterState, ClusterState clusterState2) {
                actionListener.onResponse(AcknowledgedResponse.TRUE);
            }
        };
        submitUnbatchedTask("set model assignment stopping", stoppingTask);
    }

    /**
     * Submits a cluster state update that removes the assignment of the given deployment. Once the
     * update has been processed, a rebalance of the remaining deployments is started against the new
     * state and the caller is acknowledged without waiting for that rebalance; its outcome is only logged.
     *
     * @param str            the deployment id
     * @param actionListener receives {@code AcknowledgedResponse.TRUE} after the removal has been
     *                       processed, or the failure of the removal task
     */
    public void removeModelAssignment(final String str, final ActionListener<AcknowledgedResponse> actionListener) {
        ClusterStateUpdateTask removalTask = new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState clusterState) {
                return removeAssignment(clusterState, str);
            }

            @Override
            public void onFailure(Exception exc) {
                actionListener.onFailure(exc);
            }

            @Override
            public void clusterStateProcessed(ClusterState clusterState, ClusterState clusterState2) {
                ActionListener<TrainedModelAssignmentMetadata> rebalanceListener = ActionListener.wrap(
                    rebalancedMetadata -> logger.debug(
                        () -> Strings.format("Successfully rebalanced model deployments after deployment [%s] was stopped", str)
                    ),
                    failure -> logger.error(
                        Strings.format("Failed to rebalance model deployments after deployment [%s] was stopped", str),
                        failure
                    )
                );
                // Kick off the rebalance first, then acknowledge the removal.
                TrainedModelAssignmentClusterService.this.rebalanceAssignments(
                    clusterState2,
                    Optional.empty(),
                    "model deployment stopped",
                    rebalanceListener
                );
                actionListener.onResponse(AcknowledgedResponse.TRUE);
            }
        };
        submitUnbatchedTask("delete model deployment assignment", removalTask);
    }

    /**
     * Submits a cluster state update that applies {@code removeAllAssignments} to the current state.
     *
     * @param actionListener receives {@code AcknowledgedResponse.TRUE} once the state update has been
     *                       processed, or the failure of the update task
     */
    public void removeAllModelAssignments(final ActionListener<AcknowledgedResponse> actionListener) {
        ClusterStateUpdateTask removeAllTask = new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState clusterState) {
                return removeAllAssignments(clusterState);
            }

            @Override
            public void onFailure(Exception exc) {
                actionListener.onFailure(exc);
            }

            @Override
            public void clusterStateProcessed(ClusterState clusterState, ClusterState clusterState2) {
                actionListener.onResponse(AcknowledgedResponse.TRUE);
            }
        };
        submitUnbatchedTask("delete all model assignments", removeAllTask);
    }

    /**
     * Applies the builder's metadata to the cluster state only when it differs from what the state
     * already holds.
     *
     * @param clusterState the state to update
     * @param builder      the proposed assignment metadata
     * @return {@code clusterState} itself when the built metadata equals the current metadata, otherwise
     *         the result of {@link #forceUpdate}
     */
    private static ClusterState update(ClusterState clusterState, TrainedModelAssignmentMetadata.Builder builder) {
        TrainedModelAssignmentMetadata proposed = builder.build();
        TrainedModelAssignmentMetadata current = TrainedModelAssignmentMetadata.fromState(clusterState);
        if (proposed.equals(current)) {
            return clusterState;
        }
        return forceUpdate(clusterState, builder);
    }

    /**
     * Builds a new cluster state carrying the builder's assignment metadata, without checking whether
     * it changed.
     * <p>
     * When the cluster's minimum transport version is on or after
     * {@code RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION} the metadata is stored under
     * {@code "trained_model_assignment"} and {@code "trained_model_allocation"} is removed; otherwise
     * the old-format metadata is stored under {@code "trained_model_allocation"}.
     *
     * @param clusterState the state to copy
     * @param builder      the assignment metadata to store
     * @return the new cluster state
     */
    private static ClusterState forceUpdate(ClusterState clusterState, TrainedModelAssignmentMetadata.Builder builder) {
        logger.debug(() -> Strings.format("updated assignments: %s", builder.build()));

        Metadata.Builder metadataBuilder = Metadata.builder(clusterState.metadata());
        boolean usesRenamedMetadata = clusterState.getMinTransportVersion().onOrAfter(RENAME_ALLOCATION_TO_ASSIGNMENT_TRANSPORT_VERSION);
        if (usesRenamedMetadata == false) {
            metadataBuilder.putCustom("trained_model_allocation", builder.buildOld());
        } else {
            metadataBuilder.putCustom("trained_model_assignment", builder.build()).removeCustom("trained_model_allocation");
        }
        return ClusterState.builder(clusterState).metadata(metadataBuilder).build();
    }

    /**
     * Synchronously rebalances assignments with the given deployment added and applies the result to
     * the cluster state.
     *
     * @param clusterState the state to rebalance
     * @param taskParams   parameters of the deployment to add
     * @return the result of {@link #update} for the rebalanced metadata
     * @throws Exception if the rebalance fails
     */
    ClusterState createModelAssignment(ClusterState clusterState, StartTrainedModelDeploymentAction.TaskParams taskParams) throws Exception {
        TrainedModelAssignmentMetadata.Builder rebalanced = rebalanceAssignments(clusterState, Optional.of(taskParams));
        return update(clusterState, rebalanced);
    }

    /**
     * Asynchronously rebalances model assignments and publishes the result as a cluster state update.
     * <p>
     * Steps:
     * <ol>
     *   <li>fetch the set of ML node architectures;</li>
     *   <li>on the utility thread pool, compute the rebalanced metadata from {@code clusterState};</li>
     *   <li>submit a cluster state update task. If the state seen by the task is no longer compatible
     *       with {@code clusterState} (see {@link #areClusterStatesCompatibleForRebalance}), the whole
     *       procedure is restarted against that newer state and the listener is left for the restarted
     *       run to complete; otherwise the precomputed metadata is applied.</li>
     * </ol>
     *
     * @param clusterState   the state the rebalance is computed from
     * @param optional       the deployment being added, if any
     * @param str            the reason for rebalancing; also used as the update task source
     * @param actionListener receives the assignment metadata of the processed state, or any failure
     */
    private void rebalanceAssignments(ClusterState clusterState, Optional<StartTrainedModelDeploymentAction.TaskParams> optional, String str, ActionListener<TrainedModelAssignmentMetadata> actionListener) {
        ActionListener<Set<String>> architecturesListener = ActionListener.wrap(
            mlNodesArchitectures -> this.threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
                logger.debug(() -> Strings.format("Rebalancing model allocations because [%s]", str));
                try {
                    final TrainedModelAssignmentMetadata.Builder rebalancedMetadata = rebalanceAssignments(clusterState, optional);
                    submitUnbatchedTask(str, new ClusterStateUpdateTask() {
                        // Set when the precomputed metadata was applied (as opposed to restarting the rebalance).
                        private volatile boolean isUpdated;
                        // Set when applying the metadata produced a different cluster state instance.
                        private volatile boolean isChanged;

                        @Override
                        public ClusterState execute(ClusterState clusterState2) {
                            ClusterState currentState = TrainedModelAssignmentClusterService.this
                                .stopPlatformSpecificModelsInHeterogeneousClusters(clusterState2, mlNodesArchitectures, optional, clusterState);
                            if (TrainedModelAssignmentClusterService.this.areClusterStatesCompatibleForRebalance(clusterState, currentState) == false) {
                                // The cluster moved on while we were computing; start over from the newer state.
                                TrainedModelAssignmentClusterService.this.rebalanceAssignments(currentState, optional, str, actionListener);
                                return currentState;
                            }
                            this.isUpdated = true;
                            ClusterState updatedState = TrainedModelAssignmentClusterService.update(currentState, rebalancedMetadata);
                            this.isChanged = updatedState != currentState;
                            return updatedState;
                        }

                        @Override
                        public void onFailure(Exception exc) {
                            actionListener.onFailure(exc);
                        }

                        @Override
                        public void clusterStateProcessed(ClusterState clusterState2, ClusterState clusterState3) {
                            if (this.isUpdated == false) {
                                // A restarted rebalance owns the listener now.
                                return;
                            }
                            if (this.isChanged) {
                                // Audit off the calling thread, on the utility pool.
                                TrainedModelAssignmentClusterService.this.threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
                                    .execute(
                                        () -> TrainedModelAssignmentClusterService.this.systemAuditor.info(
                                            Messages.getMessage("Rebalanced trained model allocations because [{0}]", str)
                                        )
                                    );
                            }
                            actionListener.onResponse(TrainedModelAssignmentMetadata.fromState(clusterState3));
                        }
                    });
                } catch (Exception e) {
                    actionListener.onFailure(e);
                }
            }),
            actionListener::onFailure
        );
        MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(
            architecturesListener,
            this.client,
            this.threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
        );
    }

    /**
     * Marks the deployment being added as stopping when the ML nodes report more than one platform
     * architecture.
     * <p>
     * Note that the stopping state is derived from {@code clusterState2}, not from {@code clusterState}.
     *
     * @param clusterState  the state returned unchanged when no stop is needed
     * @param set           the platform architectures detected among ML nodes
     * @param optional      the deployment being added, if any
     * @param clusterState2 the state passed to {@link #callSetToStopping} when a stop is needed
     * @return {@code clusterState} when there is at most one architecture or no deployment is being
     *         added; otherwise the state returned by {@link #callSetToStopping}
     */
    ClusterState stopPlatformSpecificModelsInHeterogeneousClusters(ClusterState clusterState, Set<String> set, Optional<StartTrainedModelDeploymentAction.TaskParams> optional, ClusterState clusterState2) {
        if (set.size() <= 1 || optional.isEmpty()) {
            return clusterState;
        }
        StartTrainedModelDeploymentAction.TaskParams modelToAdd = optional.get();
        String reasonToStop = Strings.format(
            "ML nodes in this cluster have multiple platform architectures, but can only have one for this model ([%s]); detected architectures: %s",
            modelToAdd.getModelId(),
            set
        );
        return callSetToStopping(reasonToStop, modelToAdd.getDeploymentId(), clusterState2);
    }

    /**
     * Instance-level wrapper around the static {@code setToStopping}; note the argument order differs.
     *
     * @param str          the reason for stopping
     * @param str2         the deployment id
     * @param clusterState the state to derive the result from
     * @return whatever {@code setToStopping(clusterState, str2, str)} returns
     */
    ClusterState callSetToStopping(String str, String str2, ClusterState clusterState) {
        return setToStopping(clusterState, str2, str);
    }

    /**
     * Decides whether a rebalance computed from one cluster state may be applied to another.
     * <p>
     * The states are compatible when they agree on, in order of evaluation: the assignable nodes, the
     * detected load of those nodes, the ML metadata, and the trained model assignment metadata.
     *
     * @param clusterState  the state the rebalance was computed from
     * @param clusterState2 the state the rebalance would be applied to
     * @return {@code true} when all four comparisons succeed
     */
    private boolean areClusterStatesCompatibleForRebalance(ClusterState clusterState, ClusterState clusterState2) {
        List<DiscoveryNode> sourceNodes = getAssignableNodes(clusterState);
        List<DiscoveryNode> targetNodes = getAssignableNodes(clusterState2);
        if (sourceNodes.equals(targetNodes) == false) {
            return false;
        }
        if (detectNodeLoads(sourceNodes, clusterState).equals(detectNodeLoads(targetNodes, clusterState2)) == false) {
            return false;
        }
        if (MlMetadata.getMlMetadata(clusterState).equals(MlMetadata.getMlMetadata(clusterState2)) == false) {
            return false;
        }
        return TrainedModelAssignmentMetadata.fromState(clusterState).equals(TrainedModelAssignmentMetadata.fromState(clusterState2));
    }

    /**
     * Synchronously computes rebalanced assignment metadata for the given cluster state.
     * <p>
     * The rebalancer runs over the assignable nodes and their detected loads; afterwards routes to
     * shutting down nodes are switched to shutting-down routes (see
     * {@link #setShuttingDownNodeRoutesToStopping}). When a deployment is being added, its allocation is
     * validated by {@link #checkModelIsFullyAllocatedIfScalingIsNotPossible}.
     *
     * @param clusterState the state to rebalance
     * @param optional     the deployment being added, if any
     * @return a builder holding the rebalanced metadata
     * @throws Exception if rebalancing or the allocation check fails
     */
    private TrainedModelAssignmentMetadata.Builder rebalanceAssignments(ClusterState clusterState, Optional<StartTrainedModelDeploymentAction.TaskParams> optional) throws Exception {
        List<DiscoveryNode> assignableNodes = getAssignableNodes(clusterState);
        logger.debug(() -> Strings.format("assignable nodes are %s", assignableNodes.stream().map(DiscoveryNode::getId).toList()));

        Map<DiscoveryNode, NodeLoad> nodeLoads = detectNodeLoads(assignableNodes, clusterState);
        TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(clusterState);
        Set<String> shuttingDownNodeIds = clusterState.metadata().nodeShutdowns().getAllNodeIds();

        TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer(
            currentMetadata,
            nodeLoads,
            this.nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(clusterState),
            optional,
            this.allocatedProcessorsScale,
            TrainedModelAssignment.useNewMemoryFields(TransportVersionUtils.getMinTransportVersion(clusterState))
        );
        TrainedModelAssignmentMetadata.Builder rebalanced = setShuttingDownNodeRoutesToStopping(
            currentMetadata,
            shuttingDownNodeIds,
            rebalancer.rebalance()
        );

        if (optional.isPresent()) {
            checkModelIsFullyAllocatedIfScalingIsNotPossible(optional.get().getDeploymentId(), rebalanced, assignableNodes);
        }
        return rebalanced;
    }

    /**
     * Re-adds, as shutting-down routes, the routes that existing assignments have to shutting down
     * nodes, provided those routes are in state {@code STARTED} or {@code STARTING}.
     * <p>
     * For a deployment missing from {@code builder}, the shutting-down routes are added to a copy of the
     * existing assignment that has been stopped with reason {@code "nodes changed"} and had its routing
     * table cleared. An assignment is written back to {@code builder} only if at least one route was
     * added.
     *
     * @param trainedModelAssignmentMetadata the assignments as they are before rebalancing
     * @param set                            ids of nodes that are shutting down
     * @param builder                        the rebalanced metadata to amend
     * @return {@code builder}, amended in place
     */
    static TrainedModelAssignmentMetadata.Builder setShuttingDownNodeRoutesToStopping(TrainedModelAssignmentMetadata trainedModelAssignmentMetadata, Set<String> set, TrainedModelAssignmentMetadata.Builder builder) {
        if (set.isEmpty()) {
            return builder;
        }

        for (TrainedModelAssignment existing : trainedModelAssignmentMetadata.allAssignments().values()) {
            String deploymentId = existing.getDeploymentId();

            TrainedModelAssignment.Builder target;
            if (builder.hasModelDeployment(existing.getDeploymentId())) {
                target = builder.getAssignment(deploymentId);
            } else {
                target = TrainedModelAssignment.Builder.fromAssignment(existing).stopAssignment("nodes changed").clearNodeRoutingTable();
            }

            boolean addedStoppingRoute = false;
            for (String nodeId : set) {
                if (existing.isRoutedToNode(nodeId) == false) {
                    continue;
                }
                RoutingInfo existingRoute = existing.getNodeRoutingTable().get(nodeId);
                if (existingRoute.getState().isAnyOf(new RoutingState[] { RoutingState.STARTED, RoutingState.STARTING }) == false) {
                    continue;
                }
                logger.debug(
                    () -> Strings.format(
                        "Found assignment deployment id: [%s] with route to shutting down node id: [%s], adding stopping route",
                        deploymentId,
                        nodeId
                    )
                );
                addedStoppingRoute = true;
                target.addOrOverwriteRoutingEntry(nodeId, TrainedModelAssignmentUtils.createShuttingDownRoute(existingRoute));
            }

            if (addedStoppingRoute) {
                builder.addOrOverwriteAssignment(deploymentId, target);
            }
        }
        return builder;
    }

    /**
     * Fails the start of a deployment that could not be fully allocated when the cluster cannot scale.
     * <p>
     * Nothing happens when scaling is possible or the assignment is satisfied by the assignable nodes.
     * Otherwise an exception with status {@code TOO_MANY_REQUESTS} is thrown; the message distinguishes
     * a partially routed assignment from one with no routes at all.
     *
     * @param str     the deployment id
     * @param builder the rebalanced metadata containing the deployment
     * @param list    the assignable nodes
     * @throws ElasticsearchStatusException when the deployment is not satisfied and scaling is not possible
     */
    private void checkModelIsFullyAllocatedIfScalingIsNotPossible(String str, TrainedModelAssignmentMetadata.Builder builder, List<DiscoveryNode> list) {
        TrainedModelAssignment assignment = builder.getAssignment(str).build();
        if (isScalingPossible(list)) {
            return;
        }
        Set<String> assignableNodeIds = list.stream().map(DiscoveryNode::getId).collect(Collectors.toSet());
        if (assignment.isSatisfied(assignableNodeIds)) {
            return;
        }

        if (assignment.getNodeRoutingTable().isEmpty()) {
            // No route at all: keep the detailed explanation in the log and in the exception cause.
            String detailedMessage = "Could not start deployment because no suitable nodes were found, allocation explanation ["
                + assignment.getReason().orElse("none")
                + "]";
            logger.warn("[{}] {}", str, detailedMessage);
            throw new ElasticsearchStatusException(
                "Could not start deployment because no ML nodes with sufficient capacity were found",
                RestStatus.TOO_MANY_REQUESTS,
                new IllegalStateException(detailedMessage)
            );
        }

        String message = "Could not start deployment because there are not enough resources to provide all requested allocations";
        logger.debug(() -> Strings.format("[%s] %s", str, message));
        throw new ElasticsearchStatusException(message, RestStatus.TOO_MANY_REQUESTS);
    }

    /**
     * Lists the nodes of {@code clusterState} that pass
     * {@code StartTrainedModelDeploymentAction.TaskParams.mayAssignToNode} and are not shutting down.
     *
     * @param clusterState the cluster state to read the nodes and shutdown markers from
     * @return the matching nodes
     */
    private static List<DiscoveryNode> getAssignableNodes(ClusterState clusterState) {
        final Set<String> shuttingDownNodeIds = nodesShuttingDown(clusterState);
        return clusterState.getNodes()
            .getNodes()
            .values()
            .stream()
            .filter(
                node -> StartTrainedModelDeploymentAction.TaskParams.mayAssignToNode(node)
                    && shuttingDownNodeIds.contains(node.getId()) == false
            )
            .toList();
    }

    /**
     * Maps each given node to the load reported for it by the node load detector.
     *
     * @param list the nodes to inspect
     * @param clusterState the cluster state passed to the detector
     * @return a map from node to its detected load
     */
    private Map<DiscoveryNode, NodeLoad> detectNodeLoads(List<DiscoveryNode> list, ClusterState clusterState) {
        final Function<DiscoveryNode, NodeLoad> loadOf = node -> this.nodeLoadDetector.detectNodeLoad(
            clusterState,
            null,
            node,
            this.maxOpenJobs,
            this.maxMemoryPercentage,
            this.useAuto
        );
        return list.stream().collect(Collectors.toMap(node -> node, loadOf));
    }

    /**
     * Tells whether the cluster could still grow, either by adding nodes or by enlarging the smallest one.
     *
     * @param list the nodes currently considered
     * @return {@code true} if {@code maxLazyMLNodes} exceeds the number of given nodes, or if the smallest
     *         known node size is below {@code maxMLNodeSize}
     */
    private boolean isScalingPossible(List<DiscoveryNode> list) {
        final OptionalLong smallestNodeSize = list.stream()
            .map(NodeLoadDetector::getNodeSize)
            .flatMapToLong(OptionalLong::stream)
            .min();

        // Horizontal scaling: more lazy nodes are allowed than nodes are present.
        if (this.maxLazyMLNodes > list.size()) {
            return true;
        }
        // Vertical scaling: the smallest node has not reached the maximum node size yet.
        return smallestNodeSize.isPresent() && smallestNodeSize.getAsLong() < this.maxMLNodeSize;
    }

    /**
     * Updates the number of allocations of a deployment, starting from the current cluster state.
     *
     * @param str the deployment id
     * @param i the requested number of allocations
     * @param actionListener notified with the resulting assignment or with the failure
     */
    public void updateNumberOfAllocations(String str, int i, ActionListener<TrainedModelAssignment> actionListener) {
        final ClusterState currentState = this.clusterService.state();
        updateNumberOfAllocations(currentState, str, i, actionListener);
    }

    /**
     * Updates the number of allocations of the deployment {@code str}, computing the new assignment from
     * {@code clusterState} and then publishing it through an unbatched cluster state update task.
     *
     * <p>The listener fails immediately if the deployment does not exist, is not in the {@code STARTED}
     * state, or if the cluster's minimum transport version is before
     * {@code DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION}. It is answered immediately with the existing
     * assignment if the requested number equals the configured one.
     *
     * @param clusterState the state the new assignment is computed from
     * @param str the deployment id
     * @param i the requested number of allocations
     * @param actionListener notified with the updated assignment or with the failure
     */
    private void updateNumberOfAllocations(ClusterState clusterState, String str, int i, ActionListener<TrainedModelAssignment> actionListener) {
        final TrainedModelAssignment existingAssignment = TrainedModelAssignmentMetadata.fromState(clusterState).getDeploymentAssignment(str);
        if (existingAssignment == null) {
            actionListener.onFailure(ExceptionsHelper.missingModelDeployment(str));
            return;
        }
        if (existingAssignment.getTaskParams().getNumberOfAllocations() == i) {
            // Nothing to change.
            actionListener.onResponse(existingAssignment);
            return;
        }
        if (existingAssignment.getAssignmentState() != AssignmentState.STARTED) {
            actionListener.onFailure(
                new ElasticsearchStatusException(
                    "cannot update deployment that is not in [{}] state",
                    RestStatus.CONFLICT,
                    AssignmentState.STARTED
                )
            );
            return;
        }
        if (clusterState.getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION)) {
            actionListener.onFailure(
                new ElasticsearchStatusException(
                    "cannot update number_of_allocations for deployment with model id [{}] while cluster upgrade is in progress.",
                    RestStatus.CONFLICT,
                    str
                )
            );
            return;
        }

        final ActionListener<ClusterState> updatedStateListener = ActionListener.wrap(
            updatedState -> submitUnbatchedTask("update model deployment number_of_allocations", new ClusterStateUpdateTask() {

                // Set only when execute() actually published the computed state.
                private volatile boolean isUpdated;

                @Override
                public ClusterState execute(ClusterState currentState) {
                    if (areClusterStatesCompatibleForRebalance(clusterState, currentState)) {
                        isUpdated = true;
                        // Publish the state computed by adjustNumberOfAllocations, not the state the
                        // task was handed; returning currentState here would silently drop the update.
                        return updatedState;
                    }
                    logger.debug(() -> Strings.format("[%s] Retrying update as cluster state has been modified", str));
                    // The state changed underneath us: recompute against the latest state and leave
                    // this task as a no-op.
                    updateNumberOfAllocations(currentState, str, i, actionListener);
                    return currentState;
                }

                @Override
                public void onFailure(Exception exc) {
                    actionListener.onFailure(exc);
                }

                @Override
                public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                    if (isUpdated == false) {
                        // A retry was scheduled from execute(); that retry answers the listener.
                        return;
                    }
                    final TrainedModelAssignment updatedAssignment = TrainedModelAssignmentMetadata.fromState(newState)
                        .getDeploymentAssignment(str);
                    if (updatedAssignment.totalTargetAllocations() > existingAssignment.totalTargetAllocations()) {
                        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
                            .execute(
                                () -> systemAuditor.info(
                                    Messages.getMessage("Rebalanced trained model allocations because [{0}]", "model deployment updated")
                                )
                            );
                    }
                    actionListener.onResponse(updatedAssignment);
                }
            }),
            actionListener::onFailure
        );

        adjustNumberOfAllocations(clusterState, existingAssignment, i, updatedStateListener);
    }

    /**
     * Computes, on the ML utility thread pool, the cluster state that results from changing the number
     * of allocations of {@code trainedModelAssignment} to {@code i}.
     *
     * @param clusterState the state to start from
     * @param trainedModelAssignment the assignment being changed
     * @param i the requested number of allocations
     * @param actionListener receives the computed cluster state or the failure
     */
    private void adjustNumberOfAllocations(ClusterState clusterState, TrainedModelAssignment trainedModelAssignment, int i, ActionListener<ClusterState> actionListener) {
        final Runnable adjustment = () -> {
            final int configuredAllocations = trainedModelAssignment.getTaskParams().getNumberOfAllocations();
            if (i <= configuredAllocations) {
                decreaseNumberOfAllocations(clusterState, trainedModelAssignment, i, actionListener);
            } else {
                increaseNumberOfAllocations(clusterState, trainedModelAssignment, i, actionListener);
            }
        };
        this.threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(adjustment);
    }

    /**
     * Raises the number of allocations of an assignment by rebalancing with the new target.
     *
     * <p>The listener receives the rebalanced cluster state when scaling is possible or when the
     * rebalanced assignment reaches at least {@code i} target allocations; otherwise it fails with
     * {@code TOO_MANY_REQUESTS}. Any exception thrown while computing the state is passed to the listener.
     *
     * @param clusterState the state to start from
     * @param trainedModelAssignment the assignment being changed
     * @param i the requested number of allocations
     * @param actionListener receives the computed cluster state or the failure
     */
    private void increaseNumberOfAllocations(ClusterState clusterState, TrainedModelAssignment trainedModelAssignment, int i, ActionListener<ClusterState> actionListener) {
        try {
            final String deploymentId = trainedModelAssignment.getDeploymentId();
            final TrainedModelAssignmentMetadata.Builder requestedMetadata = TrainedModelAssignmentMetadata.builder(clusterState)
                .updateAssignment(deploymentId, TrainedModelAssignment.Builder.fromAssignment(trainedModelAssignment).setNumberOfAllocations(i));
            final ClusterState stateWithRequest = update(clusterState, requestedMetadata);
            final TrainedModelAssignmentMetadata.Builder rebalancedMetadata = rebalanceAssignments(stateWithRequest, Optional.empty());

            final boolean canProvideAllocations = isScalingPossible(getAssignableNodes(clusterState))
                || rebalancedMetadata.getAssignment(deploymentId).build().totalTargetAllocations() >= i;
            if (canProvideAllocations == false) {
                actionListener.onFailure(
                    new ElasticsearchStatusException(
                        "Could not update deployment because there are not enough resources to provide all requested allocations",
                        RestStatus.TOO_MANY_REQUESTS
                    )
                );
                return;
            }
            actionListener.onResponse(update(clusterState, rebalancedMetadata));
        } catch (Exception e) {
            actionListener.onFailure(e);
        }
    }

    /**
     * Lowers the number of allocations of an assignment and hands the resulting state to the listener.
     *
     * <p>When {@code i} is below the current total of target allocations, an {@code AllocationReducer}
     * picks the allocations to drop; otherwise only the configured number is changed. The assignment
     * reason is cleared whenever {@code i} does not exceed the current total.
     *
     * @param clusterState the state to start from
     * @param trainedModelAssignment the assignment being changed
     * @param i the requested number of allocations
     * @param actionListener receives the computed cluster state
     */
    private void decreaseNumberOfAllocations(ClusterState clusterState, TrainedModelAssignment trainedModelAssignment, int i, ActionListener<ClusterState> actionListener) {
        final int currentTargetAllocations = trainedModelAssignment.totalTargetAllocations();

        final TrainedModelAssignment.Builder updatedAssignment;
        if (i < currentTargetAllocations) {
            final AllocationReducer reducer = new AllocationReducer(
                trainedModelAssignment,
                this.nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(clusterState)
            );
            updatedAssignment = reducer.reduceTo(i);
        } else {
            updatedAssignment = TrainedModelAssignment.Builder.fromAssignment(trainedModelAssignment).setNumberOfAllocations(i);
        }

        if (i <= currentTargetAllocations) {
            updatedAssignment.setReason((String) null);
        }

        final TrainedModelAssignmentMetadata.Builder metadataBuilder = TrainedModelAssignmentMetadata.builder(clusterState);
        metadataBuilder.updateAssignment(trainedModelAssignment.getDeploymentId(), updatedAssignment);
        actionListener.onResponse(update(clusterState, metadataBuilder));
    }

    /**
     * Returns a cluster state in which the deployment {@code str} is stopped with the reason {@code str2}.
     *
     * @param clusterState the state to start from
     * @param str the deployment id
     * @param str2 the reason passed to {@code stopAssignment}
     * @return {@code clusterState} itself if the assignment is already {@code STOPPING}, otherwise the
     *         updated state
     * @throws ResourceNotFoundException if no assignment exists for the deployment
     */
    static ClusterState setToStopping(ClusterState clusterState, String str, String str2) {
        final TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(clusterState);
        final TrainedModelAssignment existingAssignment = currentMetadata.getDeploymentAssignment(str);
        if (existingAssignment == null) {
            throw new ResourceNotFoundException("assignment with id [{}] not found", str);
        }

        final boolean alreadyStopping = existingAssignment.getAssignmentState().equals(AssignmentState.STOPPING);
        if (alreadyStopping) {
            return clusterState;
        }

        final TrainedModelAssignmentMetadata.Builder metadataBuilder = TrainedModelAssignmentMetadata.builder(clusterState);
        metadataBuilder.getAssignment(str).stopAssignment(str2);
        return update(clusterState, metadataBuilder);
    }

    /**
     * Applies a routing update sent by a node to the assignment of a deployment.
     *
     * <p>A {@code STOPPED} update removes the node's routing entry; if the assignment is missing or not
     * routed to the node, the state is returned unchanged. For any other update, nothing changes while
     * the assignment is {@code STOPPING}; otherwise the node's existing routing entry is replaced with
     * the result of applying the update to it. After either change the assignment state is recalculated.
     *
     * @param clusterState the state to start from
     * @param request carries the deployment id, the node id and the update to apply
     * @return the resulting cluster state, or {@code clusterState} itself when nothing changes
     * @throws ResourceNotFoundException for a non-{@code STOPPED} update when the assignment does not
     *         exist or is not routed to the node
     */
    static ClusterState updateModelRoutingTable(ClusterState clusterState, UpdateTrainedModelAssignmentRoutingInfoAction.Request request) {
        final String deploymentId = request.getDeploymentId();
        final String nodeId = request.getNodeId();
        final TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(clusterState);
        logger.trace(
            () -> Strings.format(
                "[%s] [%s] current metadata before update %s",
                deploymentId,
                nodeId,
                org.elasticsearch.common.Strings.toString(currentMetadata)
            )
        );
        final TrainedModelAssignment existingAssignment = currentMetadata.getDeploymentAssignment(deploymentId);
        final TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState);

        final boolean nodeReportedStopped = request.getUpdate()
            .getStateAndReason()
            .map(stateAndReason -> stateAndReason.getState().equals(RoutingState.STOPPED))
            .orElse(false);
        if (nodeReportedStopped) {
            if (existingAssignment == null || existingAssignment.isRoutedToNode(nodeId) == false) {
                // Nothing left to remove.
                return clusterState;
            }
            builder.getAssignment(deploymentId).removeRoutingEntry(nodeId).calculateAndSetAssignmentState();
            return update(clusterState, builder);
        }

        if (existingAssignment == null) {
            throw new ResourceNotFoundException("assignment with id [{}] not found", deploymentId);
        }
        if (existingAssignment.getAssignmentState().equals(AssignmentState.STOPPING)) {
            // Updates are ignored once the assignment is stopping.
            logger.debug(
                () -> Strings.format(
                    "[%s] requested update from node [%s] while stopping; update was [%s]",
                    deploymentId,
                    nodeId,
                    request.getUpdate()
                )
            );
            return clusterState;
        }
        if (existingAssignment.isRoutedToNode(nodeId) == false) {
            throw new ResourceNotFoundException("assignment with id [{}] is not routed to node [{}]", deploymentId, nodeId);
        }

        final RoutingInfo updatedRoute = request.getUpdate().apply(existingAssignment.getNodeRoutingTable().get(nodeId));
        builder.getAssignment(deploymentId).updateExistingRoutingEntry(nodeId, updatedRoute).calculateAndSetAssignmentState();
        return update(clusterState, builder);
    }

    /**
     * Returns a cluster state without the assignment of the deployment {@code str}.
     *
     * @param clusterState the state to start from
     * @param str the deployment id
     * @return the updated cluster state
     * @throws ResourceNotFoundException if the deployment has no assignment
     */
    static ClusterState removeAssignment(ClusterState clusterState, String str) {
        final TrainedModelAssignmentMetadata.Builder metadataBuilder = TrainedModelAssignmentMetadata.builder(clusterState);
        final boolean assignmentExists = metadataBuilder.hasModelDeployment(str);
        if (assignmentExists == false) {
            throw new ResourceNotFoundException("assignment for deployment with id [{}] not found", str);
        }
        logger.debug(() -> Strings.format("[%s] removing assignment", str));
        return update(clusterState, metadataBuilder.removeAssignment(str));
    }

    /**
     * Returns a cluster state with empty trained model assignment metadata.
     *
     * @param clusterState the state to start from
     * @return {@code clusterState} itself if it holds no assignments, otherwise the state produced by
     *         {@code forceUpdate} with an empty metadata builder
     */
    static ClusterState removeAllAssignments(ClusterState clusterState) {
        if (TrainedModelAssignmentMetadata.fromState(clusterState).allAssignments().isEmpty()) {
            return clusterState;
        }
        return forceUpdate(clusterState, TrainedModelAssignmentMetadata.Builder.empty());
    }

    /**
     * Determines why, if at all, the given cluster change calls for rebalancing model assignments.
     *
     * <p>Checks are made in this order: stopped ML jobs, changed ML nodes, outdated assignments. The
     * first one that applies provides the reason.
     *
     * @param clusterChangedEvent the cluster change to inspect
     * @return the reason to rebalance, or empty if there are no assignments or no check applies
     */
    static Optional<String> detectReasonToRebalanceModels(ClusterChangedEvent clusterChangedEvent) {
        final TrainedModelAssignmentMetadata newMetadata = TrainedModelAssignmentMetadata.fromState(clusterChangedEvent.state());
        // Without any assignment there is nothing to rebalance.
        if (newMetadata == null || newMetadata.allAssignments().isEmpty()) {
            return Optional.empty();
        }

        return detectReasonIfMlJobsStopped(clusterChangedEvent).or(() -> {
            // Typed as String so the supplier yields an Optional<String>, as Optional.or requires.
            String reason = null;
            if (haveMlNodesChanged(clusterChangedEvent, newMetadata)) {
                reason = "nodes changed";
            } else if (newMetadata.hasOutdatedAssignments()) {
                reason = "outdated assignments detected";
            }
            return Optional.ofNullable(reason);
        });
    }

    /**
     * Builds a rebalance reason when ML process tasks present in the previous state are absent from the
     * current one.
     *
     * @param clusterChangedEvent the cluster change to inspect
     * @return {@code "ML [<type>] job stopped"} for a single stopped task type,
     *         {@code "ML <types> jobs stopped"} for several, or empty when the persistent tasks metadata
     *         did not change, the previous state had none, or no ML process task disappeared
     */
    static Optional<String> detectReasonIfMlJobsStopped(ClusterChangedEvent clusterChangedEvent) {
        if (clusterChangedEvent.changedCustomMetadataSet().contains("persistent_tasks") == false) {
            return Optional.empty();
        }

        final PersistentTasksCustomMetadata previousTasks = clusterChangedEvent.previousState().getMetadata().custom("persistent_tasks");
        if (previousTasks == null) {
            // The metadata was just added, so no task existed before and none can have stopped.
            // Without this guard the task lookup below would fail with a NullPointerException.
            return Optional.empty();
        }
        final PersistentTasksCustomMetadata currentTasks = clusterChangedEvent.state().getMetadata().custom("persistent_tasks");

        final Set<String> currentMlTaskIds = findMlProcessTaskIds(currentTasks);
        final Set<String> stoppedTaskTypes = findMlProcessTaskIds(previousTasks).stream()
            // Keep only the ids that are gone from the current state.
            .filter(taskId -> currentMlTaskIds.contains(taskId) == false)
            .map(previousTasks::getTask)
            .map(task -> task.getTaskName())
            .map(MlTasks::prettyPrintTaskName)
            .collect(Collectors.toSet());

        if (stoppedTaskTypes.size() == 1) {
            return Optional.of("ML [" + stoppedTaskTypes.iterator().next() + "] job stopped");
        }
        if (stoppedTaskTypes.size() > 1) {
            return Optional.of("ML " + stoppedTaskTypes + " jobs stopped");
        }
        return Optional.empty();
    }

    /**
     * Collects the ids of the tasks returned by {@code MlTasks.findMlProcessTasks}.
     *
     * @param persistentTasksCustomMetadata the persistent tasks metadata, may be {@code null}
     * @return the task ids, or an empty set when the metadata is {@code null}
     */
    private static Set<String> findMlProcessTaskIds(@Nullable PersistentTasksCustomMetadata persistentTasksCustomMetadata) {
        if (persistentTasksCustomMetadata == null) {
            return Set.of();
        }
        return MlTasks.findMlProcessTasks(persistentTasksCustomMetadata)
            .stream()
            .map(task -> task.getId())
            .collect(Collectors.toSet());
    }

    /**
     * Decides whether node changes, or changes to the node shutdown metadata, affect any assignment that
     * is not {@code STOPPING}.
     *
     * <p>When the {@code node_shutdown} metadata changed, nodes newly marked as shutting down are counted
     * as removed, and nodes that are present in the cluster and no longer marked are counted as added.
     *
     * @param clusterChangedEvent the cluster change to inspect
     * @param trainedModelAssignmentMetadata the assignments to check the changes against
     * @return {@code true} if an assignment is routed to a node that started shutting down (and that route
     *         is not {@code STOPPING}), is routed to a removed node that is not shutting down, or if an
     *         added node may take assignments and is not shutting down
     */
    static boolean haveMlNodesChanged(ClusterChangedEvent clusterChangedEvent, TrainedModelAssignmentMetadata trainedModelAssignmentMetadata) {
        final boolean shutdownMetadataChanged = clusterChangedEvent.changedCustomMetadataSet().contains("node_shutdown");
        if (clusterChangedEvent.nodesChanged() == false && shutdownMetadataChanged == false) {
            return false;
        }

        // Tags every log line of this invocation so that interleaved calls can be told apart.
        final String eventIdentity = Long.toHexString(System.nanoTime());

        final Set<String> shuttingDownNodes = nodesShuttingDown(clusterChangedEvent.state());
        final DiscoveryNodes.Delta nodesDelta = clusterChangedEvent.nodesDelta();
        final Set<String> removedNodes = nodesDelta.removedNodes().stream().map(DiscoveryNode::getId).collect(Collectors.toSet());
        final Set<String> addedNodes = nodesDelta.addedNodes().stream().map(DiscoveryNode::getId).collect(Collectors.toSet());
        logger.debug(
            () -> Strings.format(
                "Initial node change info; identity: %s; removed nodes: %s; added nodes: %s; shutting down nodes: %s",
                eventIdentity,
                removedNodes,
                addedNodes,
                shuttingDownNodes
            )
        );

        final Set<String> exitingShutDownNodes;
        if (shutdownMetadataChanged) {
            final Set<String> previousShuttingDownNodes = nodesShuttingDown(clusterChangedEvent.previousState());
            final Set<String> presentNodes = clusterChangedEvent.state().nodes().stream().map(DiscoveryNode::getId).collect(Collectors.toSet());

            // Nodes that were marked for shutdown before, are not any more, and still exist.
            final Set<String> returningShutDownNodes = Sets.intersection(
                presentNodes,
                Sets.difference(previousShuttingDownNodes, shuttingDownNodes)
            );
            addedNodes.addAll(returningShutDownNodes);

            // Nodes that were marked for shutdown by this event.
            exitingShutDownNodes = Sets.difference(shuttingDownNodes, previousShuttingDownNodes);
            removedNodes.addAll(exitingShutDownNodes);

            logger.debug(
                () -> Strings.format(
                    "Shutting down nodes were changed; identity: %s; previous shutting down nodes: %s; returning nodes: %s",
                    eventIdentity,
                    previousShuttingDownNodes,
                    returningShutDownNodes
                )
            );
        } else {
            exitingShutDownNodes = Collections.emptySet();
        }

        logger.debug(
            () -> Strings.format(
                "identity: %s; added nodes %s; removed nodes %s; shutting down nodes %s; exiting shutdown nodes %s",
                eventIdentity,
                addedNodes,
                removedNodes,
                shuttingDownNodes,
                exitingShutDownNodes
            )
        );

        for (TrainedModelAssignment assignment : trainedModelAssignmentMetadata.allAssignments().values()) {
            if (assignment.getAssignmentState().equals(AssignmentState.STOPPING)) {
                continue;
            }
            for (String nodeId : exitingShutDownNodes) {
                if (assignment.isRoutedToNode(nodeId)
                    && assignment.getNodeRoutingTable().get(nodeId).getState() != RoutingState.STOPPING) {
                    logger.debug(
                        () -> Strings.format(
                            "should rebalance because model deployment [%s] has allocations on shutting down node [%s]",
                            assignment.getDeploymentId(),
                            nodeId
                        )
                    );
                    return true;
                }
            }
            for (String nodeId : removedNodes) {
                if (assignment.isRoutedToNode(nodeId) && shuttingDownNodes.contains(nodeId) == false) {
                    logger.debug(
                        () -> Strings.format(
                            "should rebalance because model deployment [%s] has allocations on removed node [%s]",
                            assignment.getDeploymentId(),
                            nodeId
                        )
                    );
                    return true;
                }
            }
            for (String nodeId : addedNodes) {
                if (StartTrainedModelDeploymentAction.TaskParams.mayAssignToNode(clusterChangedEvent.state().nodes().get(nodeId))
                    && shuttingDownNodes.contains(nodeId) == false) {
                    logger.debug(() -> Strings.format("should rebalance because ML eligible node [%s] was added", nodeId));
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns the ids of all nodes listed in the node shutdown metadata of {@code clusterState}.
     *
     * @param clusterState the cluster state to read from
     * @return the ids of the nodes that have a shutdown record
     */
    static Set<String> nodesShuttingDown(ClusterState clusterState) {
        final Set<String> shuttingDownNodeIds = clusterState.metadata().nodeShutdowns().getAllNodeIds();
        return shuttingDownNodeIds;
    }
}
