package org.elasticsearch.xpack.ml.inference.assignment;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Strings;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlanner;
import org.elasticsearch.xpack.ml.inference.assignment.planning.ZoneAwareAssignmentPlanner;
import org.elasticsearch.xpack.ml.job.NodeLoad;
import org.elasticsearch.xpack.ml.utils.MlProcessors;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.class */
class TrainedModelAssignmentRebalancer {
    /** Class logger; used for trace/debug output and for warnings about nodes that are skipped. */
    private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentRebalancer.class);
    /** Assignment metadata as it is before rebalancing; null-checked in the constructor. */
    private final TrainedModelAssignmentMetadata currentMetadata;
    /** Load information per node; nodes absent from this map, or whose load carries an error, are not offered to the normal-priority planner. */
    private final Map<DiscoveryNode, NodeLoad> nodeLoads;
    /** ML nodes grouped under a {@code List<String>} key (presumably the zone attribute values - confirm with the caller). */
    private final Map<List<String>, Collection<DiscoveryNode>> mlNodesByZone;
    /** Parameters of a deployment to add during this rebalance; empty when only existing deployments are rebalanced. */
    private final Optional<StartTrainedModelDeploymentAction.TaskParams> deploymentToAdd;
    /** Scale value forwarded to {@code MlProcessors.get} when a node's processor count is computed. */
    private final int allocatedProcessorsScale;
    /** Selects whether the per-deployment / per-allocation memory values of the task params are passed to the plan deployments (otherwise 0 is passed). */
    private final boolean useNewMemoryFields;

    /**
     * Creates a rebalancer over a snapshot of the current assignments and node state.
     * <p>
     * The enclosing class is package-private, so the {@code public} modifier does not
     * expose this constructor outside the package.
     *
     * @param trainedModelAssignmentMetadata current assignment metadata; must not be null
     * @param map load information keyed by node; must not be null
     * @param map2 ML nodes grouped under a {@code List<String>} key; must not be null
     * @param optional parameters of a deployment to add, or empty; must not be null
     * @param i allocated-processors scale forwarded to {@code MlProcessors.get}
     * @param z whether the per-deployment and per-allocation memory fields are passed to plan deployments
     * @throws NullPointerException if any of the reference arguments is null
     */
    public TrainedModelAssignmentRebalancer(
        TrainedModelAssignmentMetadata trainedModelAssignmentMetadata,
        Map<DiscoveryNode, NodeLoad> map,
        Map<List<String>, Collection<DiscoveryNode>> map2,
        Optional<StartTrainedModelDeploymentAction.TaskParams> optional,
        int i,
        boolean z
    ) {
        // Null checks run in declaration order, exactly as before; the generic types make casts unnecessary.
        this.currentMetadata = Objects.requireNonNull(trainedModelAssignmentMetadata);
        this.nodeLoads = Objects.requireNonNull(map);
        this.mlNodesByZone = Objects.requireNonNull(map2);
        this.deploymentToAdd = Objects.requireNonNull(optional);
        this.allocatedProcessorsScale = i;
        this.useNewMemoryFields = z;
    }

    /**
     * Computes the rebalanced assignment metadata.
     * <p>
     * When there is no deployment to add and every existing assignment is satisfied with no
     * outdated routing entries, the current metadata is copied unchanged. Otherwise a new
     * assignment plan is computed and converted into metadata.
     *
     * @return a builder holding the resulting assignment metadata
     * @throws ResourceAlreadyExistsException if the deployment to add already exists in the current metadata
     */
    public TrainedModelAssignmentMetadata.Builder rebalance() {
        if (this.deploymentToAdd.isPresent()) {
            StartTrainedModelDeploymentAction.TaskParams requested = this.deploymentToAdd.get();
            if (this.currentMetadata.hasDeployment(requested.getDeploymentId())) {
                throw new ResourceAlreadyExistsException(
                    "[{}] assignment for deployment with model [{}] already exists",
                    requested.getDeploymentId(),
                    requested.getModelId()
                );
            }
        }

        // The satisfaction check is only evaluated when nothing is being added (short-circuit).
        boolean nothingToDo = this.deploymentToAdd.isEmpty() && areAllModelsSatisfiedAndNoOutdatedRoutingEntries();
        if (nothingToDo) {
            logger.trace(() -> "No need to rebalance as all model deployments are satisfied");
            return TrainedModelAssignmentMetadata.Builder.fromMetadata(this.currentMetadata);
        }

        AssignmentPlan plan = computeAssignmentPlan();
        return buildAssignmentsFromPlan(plan);
    }

    /**
     * Tells whether rebalancing can be skipped for the existing assignments.
     *
     * @return {@code true} only if every current assignment reports itself satisfied for the ids of
     *         the nodes in {@code nodeLoads} and none has outdated routing entries
     */
    private boolean areAllModelsSatisfiedAndNoOutdatedRoutingEntries() {
        Set<String> loadedNodeIds = this.nodeLoads.keySet().stream().map(DiscoveryNode::getId).collect(Collectors.toSet());
        return this.currentMetadata.allAssignments()
            .values()
            .stream()
            .allMatch(assignment -> assignment.isSatisfied(loadedNodeIds) && !assignment.hasOutdatedRoutingEntries());
    }

    /**
     * Computes the full assignment plan: normal-priority deployments are planned first, then
     * low-priority deployments are planned against what the first plan leaves over, and the
     * two plans are merged.
     *
     * @return the merged assignment plan
     */
    AssignmentPlan computeAssignmentPlan() {
        Map<List<String>, List<AssignmentPlan.Node>> nodesByZone = createNodesByZoneMap();

        // Ids of every node that made it into the planner input.
        Set<String> assignableNodeIds = nodesByZone.values()
            .stream()
            .flatMap(List::stream)
            .map(AssignmentPlan.Node::id)
            .collect(Collectors.toSet());

        AssignmentPlan normalPriorityPlan = computePlanForNormalPriorityModels(nodesByZone, assignableNodeIds);
        AssignmentPlan lowPriorityPlan = computePlanForLowPriorityModels(assignableNodeIds, normalPriorityPlan);
        return mergePlans(nodesByZone, normalPriorityPlan, lowPriorityPlan);
    }

    /**
     * Merges two plans into a single plan built over the original planner nodes.
     *
     * @param map planner nodes grouped by zone; all of them become nodes of the merged plan
     * @param assignmentPlan first plan to copy (called with the normal-priority plan)
     * @param assignmentPlan2 second plan to copy (called with the low-priority plan)
     * @return a plan containing the deployments and assignments of both inputs
     * @throws IllegalStateException if two nodes in {@code map} share the same id
     */
    private static AssignmentPlan mergePlans(Map<List<String>, List<AssignmentPlan.Node>> map, AssignmentPlan assignmentPlan, AssignmentPlan assignmentPlan2) {
        // Flatten the per-zone node lists (the previous code referenced an unresolved placeholder here).
        List<AssignmentPlan.Node> allNodes = new ArrayList<>();
        map.values().forEach(allNodes::addAll);

        List<AssignmentPlan.Deployment> allDeployments = new ArrayList<>();
        allDeployments.addAll(assignmentPlan.models());
        allDeployments.addAll(assignmentPlan2.models());

        // Assignments in the source plans may reference different Node instances; resolve them by id.
        Map<String, AssignmentPlan.Node> originalNodeById = allNodes.stream()
            .collect(Collectors.toMap(AssignmentPlan.Node::id, Function.identity()));

        AssignmentPlan.Builder merged = AssignmentPlan.builder(allNodes, allDeployments);
        copyAssignments(assignmentPlan, merged, originalNodeById);
        copyAssignments(assignmentPlan2, merged, originalNodeById);
        return merged.build();
    }

    /**
     * Replays every assignment of {@code assignmentPlan} onto {@code builder}, mapping each
     * assigned node to the node with the same id in {@code map}.
     *
     * @param assignmentPlan plan whose assignments are copied
     * @param builder destination plan builder
     * @param map destination nodes keyed by node id
     */
    private static void copyAssignments(AssignmentPlan assignmentPlan, AssignmentPlan.Builder builder, Map<String, AssignmentPlan.Node> map) {
        for (AssignmentPlan.Deployment deployment : assignmentPlan.models()) {
            Map<AssignmentPlan.Node, Integer> nodeAssignments = assignmentPlan.assignments(deployment).orElse(Map.of());
            for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
                AssignmentPlan.Node originalNode = map.get(assignment.getKey().id());
                builder.assignModelToNode(deployment, originalNode, assignment.getValue());
                // The destination nodes start with all of their available memory. Memory only needs to be
                // accounted manually for deployments that already hold allocations on the node; charging it
                // for every assignment would count brand-new placements twice.
                // NOTE(review): relies on AssignmentPlan.Builder#assignModelToNode already deducting memory
                // for placements without current allocations - confirm against the builder.
                if (deployment.currentAllocationsByNodeId().containsKey(originalNode.id())) {
                    builder.accountMemory(deployment, originalNode);
                }
            }
        }
    }

    /**
     * Plans all deployments whose priority is not {@link Priority#LOW}, including the
     * deployment to add when it is not low priority, using the zone-aware planner.
     *
     * @param map planner nodes grouped by zone
     * @param set ids of the nodes that can take assignments; routing entries on other nodes are ignored
     * @return the plan computed by {@link ZoneAwareAssignmentPlanner}
     */
    private AssignmentPlan computePlanForNormalPriorityModels(Map<List<String>, List<AssignmentPlan.Node>> map, Set<String> set) {
        List<AssignmentPlan.Deployment> planDeployments = new ArrayList<>();

        for (TrainedModelAssignment assignment : this.currentMetadata.allAssignments().values()) {
            StartTrainedModelDeploymentAction.TaskParams params = assignment.getTaskParams();
            if (params.getPriority() == Priority.LOW) {
                continue;
            }
            Map<String, Integer> currentAllocationsByNodeId = assignment.getNodeRoutingTable()
                .entrySet()
                .stream()
                // Drop routes to nodes that are not assignable.
                .filter(route -> set.contains(route.getKey()))
                // Keep only routes that have both current and target allocations.
                .filter(route -> route.getValue().getCurrentAllocations() > 0 && route.getValue().getTargetAllocations() > 0)
                .filter(route -> route.getValue().getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED, RoutingState.FAILED))
                .collect(Collectors.toMap(Map.Entry::getKey, route -> route.getValue().getTargetAllocations()));

            planDeployments.add(
                new AssignmentPlan.Deployment(
                    assignment.getDeploymentId(),
                    params.getModelBytes(),
                    params.getNumberOfAllocations(),
                    params.getThreadsPerAllocation(),
                    currentAllocationsByNodeId,
                    assignment.getMaxAssignedAllocations(),
                    // The new memory fields are passed only when enabled; otherwise 0.
                    this.useNewMemoryFields ? params.getPerDeploymentMemoryBytes() : 0L,
                    this.useNewMemoryFields ? params.getPerAllocationMemoryBytes() : 0L
                )
            );
        }

        if (this.deploymentToAdd.isPresent() && this.deploymentToAdd.get().getPriority() != Priority.LOW) {
            StartTrainedModelDeploymentAction.TaskParams taskParams = this.deploymentToAdd.get();
            // A deployment being added has no current allocations and no max assigned allocations yet.
            planDeployments.add(
                new AssignmentPlan.Deployment(
                    taskParams.getDeploymentId(),
                    taskParams.getModelBytes(),
                    taskParams.getNumberOfAllocations(),
                    taskParams.getThreadsPerAllocation(),
                    Map.of(),
                    0,
                    this.useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0L,
                    this.useNewMemoryFields ? taskParams.getPerAllocationMemoryBytes() : 0L
                )
            );
        }
        return new ZoneAwareAssignmentPlanner(map, planDeployments).computePlan();
    }

    /**
     * Plans all {@link Priority#LOW} deployments, including the deployment to add when it is
     * low priority, against the memory left over by the given plan.
     * <p>
     * Every ML node is offered with a fixed core count of 100 and with the memory that
     * {@code assignmentPlan} reports as remaining for it.
     *
     * @param set ids of the nodes that can take assignments
     * @param assignmentPlan plan that excludes low-priority deployments; supplies remaining node memory
     * @return the plan computed by {@link AssignmentPlanner} for low-priority deployments
     */
    private AssignmentPlan computePlanForLowPriorityModels(Set<String> set, AssignmentPlan assignmentPlan) {
        // Fixed core count given to every node for low-priority planning.
        final int maxLowPriorityModelsPerNode = 100;

        List<AssignmentPlan.Node> planNodes = this.mlNodesByZone.values()
            .stream()
            .flatMap(Collection::stream)
            .map(
                discoveryNode -> new AssignmentPlan.Node(
                    discoveryNode.getId(),
                    assignmentPlan.getRemainingNodeMemory(discoveryNode.getId()),
                    maxLowPriorityModelsPerNode
                )
            )
            .toList();

        // Running tally of free memory per node; findFittingAssignments decrements it as it keeps assignments.
        Map<String, Long> remainingNodeMemory = new HashMap<>();
        planNodes.forEach(node -> remainingNodeMemory.put(node.id(), node.availableMemoryBytes()));

        // Smallest estimated memory first, so existing assignments are kept in that order.
        List<TrainedModelAssignment> lowPriorityAssignments = this.currentMetadata.allAssignments()
            .values()
            .stream()
            .filter(assignment -> assignment.getTaskParams().getPriority() == Priority.LOW)
            .sorted(Comparator.comparingLong(assignment -> assignment.getTaskParams().estimateMemoryUsageBytes()))
            .toList();

        List<AssignmentPlan.Deployment> planDeployments = new ArrayList<>();
        for (TrainedModelAssignment assignment : lowPriorityAssignments) {
            StartTrainedModelDeploymentAction.TaskParams params = assignment.getTaskParams();
            planDeployments.add(
                new AssignmentPlan.Deployment(
                    assignment.getDeploymentId(),
                    params.getModelBytes(),
                    params.getNumberOfAllocations(),
                    params.getThreadsPerAllocation(),
                    findFittingAssignments(assignment, set, remainingNodeMemory),
                    assignment.getMaxAssignedAllocations(),
                    Priority.LOW,
                    // Same rule as for normal-priority deployments: new memory fields only when enabled.
                    // The flag was previously negated here, which passed the new fields exactly when
                    // they were meant to be disabled.
                    this.useNewMemoryFields ? params.getPerDeploymentMemoryBytes() : 0L,
                    this.useNewMemoryFields ? params.getPerAllocationMemoryBytes() : 0L
                )
            );
        }

        if (this.deploymentToAdd.isPresent() && this.deploymentToAdd.get().getPriority() == Priority.LOW) {
            StartTrainedModelDeploymentAction.TaskParams taskParams = this.deploymentToAdd.get();
            planDeployments.add(
                new AssignmentPlan.Deployment(
                    taskParams.getDeploymentId(),
                    taskParams.getModelBytes(),
                    taskParams.getNumberOfAllocations(),
                    taskParams.getThreadsPerAllocation(),
                    Map.of(),
                    0,
                    Priority.LOW,
                    this.useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0L,
                    this.useNewMemoryFields ? taskParams.getPerAllocationMemoryBytes() : 0L
                )
            );
        }

        logger.debug(
            () -> Strings.format("Computing plan for low priority deployments. CPU cores fixed to [%s].", maxLowPriorityModelsPerNode)
        );
        return new AssignmentPlanner(planNodes, planDeployments).computePlan();
    }

    /**
     * Selects the existing routes of a deployment that can be kept given the memory still
     * free on each node, and reserves that memory in {@code map}.
     * <p>
     * A route is considered when its node is in {@code set}, it has both current and target
     * allocations, and its state is STARTING, STARTED or FAILED. It is kept when the node's
     * remaining memory is at least the deployment's estimated memory usage.
     *
     * @param trainedModelAssignment the existing assignment
     * @param set ids of the nodes that can take assignments
     * @param map remaining memory in bytes per node id; decremented for every kept route
     * @return target allocations per node id for the routes that are kept
     */
    private static Map<String, Integer> findFittingAssignments(TrainedModelAssignment trainedModelAssignment, Set<String> set, Map<String, Long> map) {
        final long requiredBytes = trainedModelAssignment.getTaskParams().estimateMemoryUsageBytes();
        Map<String, Integer> fittingAssignments = new HashMap<>();

        for (Map.Entry<String, RoutingInfo> route : trainedModelAssignment.getNodeRoutingTable().entrySet()) {
            String nodeId = route.getKey();
            RoutingInfo routingInfo = route.getValue();
            if (!set.contains(nodeId)) {
                continue;
            }
            // Same rule as the normal-priority path: ignore routes lacking current or target allocations.
            if (routingInfo.getCurrentAllocations() <= 0 || routingInfo.getTargetAllocations() <= 0) {
                continue;
            }
            if (!routingInfo.getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED, RoutingState.FAILED)) {
                continue;
            }
            // A node without a memory entry cannot fit anything; skip it instead of failing on unboxing.
            Long remainingBytes = map.get(nodeId);
            if (remainingBytes != null && remainingBytes >= requiredBytes) {
                fittingAssignments.put(nodeId, routingInfo.getTargetAllocations());
                map.put(nodeId, remainingBytes - requiredBytes);
            }
        }
        return fittingAssignments;
    }

    /**
     * Converts {@code mlNodesByZone} into planner nodes, keeping the same keys.
     * <p>
     * A node is skipped, with a warning, when {@code nodeLoads} has no entry for it or when
     * its load carries a non-empty error.
     *
     * @return planner nodes grouped under the same keys as {@code mlNodesByZone}
     */
    private Map<List<String>, List<AssignmentPlan.Node>> createNodesByZoneMap() {
        Map<List<String>, List<AssignmentPlan.Node>> plannerNodesByZone = new HashMap<>();
        for (Map.Entry<List<String>, Collection<DiscoveryNode>> zone : this.mlNodesByZone.entrySet()) {
            List<AssignmentPlan.Node> plannerNodes = new ArrayList<>();
            for (DiscoveryNode candidate : zone.getValue()) {
                if (!this.nodeLoads.containsKey(candidate)) {
                    logger.warn(Strings.format("ignoring node [%s] as no load could be detected", candidate.getId()));
                    continue;
                }
                NodeLoad load = this.nodeLoads.get(candidate);
                if (!org.elasticsearch.common.Strings.isNullOrEmpty(load.getError())) {
                    logger.warn(
                        Strings.format("ignoring node [%s] as detecting its load failed with [%s]", candidate.getId(), load.getError())
                    );
                    continue;
                }
                long availableMemory = getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(load);
                int processors = MlProcessors.get(candidate, this.allocatedProcessorsScale).roundUp();
                plannerNodes.add(new AssignmentPlan.Node(candidate.getId(), availableMemory, processors));
            }
            plannerNodesByZone.put(zone.getKey(), plannerNodes);
        }
        return plannerNodesByZone;
    }

    /**
     * Computes the memory figure handed to the planner for a node.
     *
     * @param nodeLoad load of the node
     * @return the node's free memory excluding per-node overhead, minus its assigned native inference memory
     */
    private static long getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(NodeLoad nodeLoad) {
        long freeExcludingOverhead = nodeLoad.getFreeMemoryExcludingPerNodeOverhead();
        long assignedNativeInference = nodeLoad.getAssignedNativeInferenceMemory();
        return freeExcludingOverhead - assignedNativeInference;
    }

    /**
     * Turns an assignment plan into assignment metadata.
     * <p>
     * For a deployment that already exists, its start time and max assigned allocations are
     * carried over. Routes that already exist keep their current allocations and state, except
     * that FAILED routes are reset to STARTING with an empty reason. New routes start in
     * STARTING with current allocations equal to target allocations.
     *
     * @param assignmentPlan the plan to convert
     * @return a builder holding one assignment per deployment in the plan
     */
    private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(AssignmentPlan assignmentPlan) {
        TrainedModelAssignmentMetadata.Builder metadataBuilder = TrainedModelAssignmentMetadata.Builder.empty();

        for (AssignmentPlan.Deployment deployment : assignmentPlan.models()) {
            TrainedModelAssignment existing = this.currentMetadata.getDeploymentAssignment(deployment.id());

            // A deployment unknown to the current metadata takes its params from the deployment being added.
            StartTrainedModelDeploymentAction.TaskParams params;
            if (existing == null && this.deploymentToAdd.isPresent()) {
                params = this.deploymentToAdd.get();
            } else {
                params = existing.getTaskParams();
            }
            TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.empty(params);

            if (existing != null) {
                assignmentBuilder.setStartTime(existing.getStartTime());
                assignmentBuilder.setMaxAssignedAllocations(existing.getMaxAssignedAllocations());
            }

            Map<AssignmentPlan.Node, Integer> plannedAllocations = assignmentPlan.assignments(deployment).orElseGet(Map::of);
            for (Map.Entry<AssignmentPlan.Node, Integer> planned : plannedAllocations.entrySet()) {
                String nodeId = planned.getKey().id();
                int targetAllocations = planned.getValue();

                if (existing != null && existing.isRoutedToNode(nodeId)) {
                    RoutingInfo previous = existing.getNodeRoutingTable().get(nodeId);
                    RoutingState state = previous.getState();
                    String reason = previous.getReason();
                    if (state == RoutingState.FAILED) {
                        // Give a failed route another chance to start.
                        state = RoutingState.STARTING;
                        reason = "";
                    }
                    assignmentBuilder.addRoutingEntry(
                        nodeId,
                        new RoutingInfo(previous.getCurrentAllocations(), targetAllocations, state, reason)
                    );
                } else {
                    assignmentBuilder.addRoutingEntry(
                        nodeId,
                        new RoutingInfo(targetAllocations, targetAllocations, RoutingState.STARTING, "")
                    );
                }
            }

            assignmentBuilder.calculateAndSetAssignmentState();
            explainAssignments(assignmentPlan, this.nodeLoads, deployment).ifPresent(assignmentBuilder::setReason);
            metadataBuilder.addNewAssignment(deployment.id(), assignmentBuilder);
        }
        return metadataBuilder;
    }

    /**
     * Builds a human-readable explanation of why a deployment did not get all of its allocations.
     *
     * @param assignmentPlan the computed plan
     * @param map load information keyed by node
     * @param deployment the deployment to explain
     * @return empty when the plan satisfies the deployment or when no node yields a reason;
     *         otherwise the per-node reasons, ordered by node id and joined with {@code |}
     */
    private Optional<String> explainAssignments(AssignmentPlan assignmentPlan, Map<DiscoveryNode, NodeLoad> map, AssignmentPlan.Deployment deployment) {
        if (assignmentPlan.satisfiesAllocations(deployment)) {
            return Optional.empty();
        }
        if (map.isEmpty()) {
            return Optional.of("No ML nodes exist in the cluster");
        }

        // TreeMap keeps the reasons sorted by node id.
        Map<String, String> reasonByNodeId = new TreeMap<>();
        for (Map.Entry<DiscoveryNode, NodeLoad> nodeAndLoad : map.entrySet()) {
            DiscoveryNode node = nodeAndLoad.getKey();
            Optional<String> reason = explainAssignment(assignmentPlan, node, nodeAndLoad.getValue(), deployment);
            reason.ifPresent(text -> reasonByNodeId.put(node.getId(), text));
        }

        if (reasonByNodeId.isEmpty()) {
            return Optional.empty();
        }
        String explanation = reasonByNodeId.entrySet()
            .stream()
            .map(e -> Strings.format("Could not assign (more) allocations on node [%s]. Reason: %s", e.getKey(), e.getValue()))
            .collect(Collectors.joining("|"));
        return Optional.of(explanation);
    }

    /**
     * Explains why a deployment could not get (more) allocations on one node.
     * <p>
     * Checks are made in this order: a load error on the node, insufficient remaining memory,
     * insufficient remaining processors.
     *
     * @param assignmentPlan the computed plan
     * @param discoveryNode the node being examined
     * @param nodeLoad load information of that node
     * @param deployment the deployment to explain
     * @return the reason, or empty when none of the checks applies
     */
    private Optional<String> explainAssignment(AssignmentPlan assignmentPlan, DiscoveryNode discoveryNode, NodeLoad nodeLoad, AssignmentPlan.Deployment deployment) {
        String loadError = nodeLoad.getError();
        if (!org.elasticsearch.common.Strings.isNullOrEmpty(loadError)) {
            return Optional.of(loadError);
        }

        String nodeId = discoveryNode.getId();
        long remainingMemory = assignmentPlan.getRemainingNodeMemory(nodeId);

        if (deployment.memoryBytes() > remainingMemory) {
            // The native code overhead is added to both figures unless it already counts as accounted for:
            // either the node had jobs/models assigned, or the plan has used some of its processors.
            boolean overheadAccountedFor = nodeLoad.getNumAssignedJobsAndModels() > 0
                || assignmentPlan.getRemainingNodeCores(nodeLoad.getNodeId()) < MlProcessors.get(discoveryNode, this.allocatedProcessorsScale).roundUp();
            long overhead = overheadAccountedFor ? 0L : MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes();
            long requiredMemory = deployment.memoryBytes() + overhead;
            long nodeFreeMemory = remainingMemory + overhead;
            long maxMlMemory = nodeLoad.getMaxMlMemory();
            return Optional.of(
                ParameterizedMessage.format(
                    "This node has insufficient available memory. Available memory for ML [{} ({})], free memory [{} ({})], estimated memory required for this model [{} ({})].",
                    new Object[] {
                        maxMlMemory,
                        ByteSizeValue.ofBytes(maxMlMemory).toString(),
                        nodeFreeMemory,
                        ByteSizeValue.ofBytes(nodeFreeMemory).toString(),
                        requiredMemory,
                        ByteSizeValue.ofBytes(requiredMemory).toString() }
                )
            );
        }

        if (deployment.threadsPerAllocation() > assignmentPlan.getRemainingNodeCores(nodeId)) {
            return Optional.of(
                ParameterizedMessage.format(
                    "This node has insufficient allocated processors. Available processors [{}], free processors [{}], processors required for each allocation of this model [{}]",
                    new Object[] {
                        MlProcessors.get(discoveryNode, this.allocatedProcessorsScale).roundUp(),
                        assignmentPlan.getRemainingNodeCores(nodeId),
                        deployment.threadsPerAllocation() }
                )
            );
        }

        return Optional.empty();
    }
}
