package org.elasticsearch.xpack.ml.inference.assignment.planning;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;

/**
 * Rounds a fractional assignment of deployments to nodes into an integral {@link AssignmentPlan}.
 *
 * <p>{@link #computePlan} receives, for every (deployment, node) pair, a fractional allocation count and a
 * fractional assignment indicator (expected to lie in {@code [0, 1]}). Pairs that can be snapped to an integral
 * assignment deterministically are committed first. Any remaining "soft" pairs are then rounded randomly in
 * {@code rounds} independent trials, and the best resulting plan (the one that compares greatest via
 * {@code AssignmentPlan#compareTo}) is returned.
 */
public class RandomizedAssignmentRounding {
    private static final Logger logger = LogManager.getLogger(RandomizedAssignmentRounding.class);

    /** Absolute tolerance used by {@link #isInteger(double)}. */
    private static final double EPS = 1e-6;

    private final Random random;
    private final int rounds;
    private final Collection<AssignmentPlan.Node> nodes;
    private final Collection<AssignmentPlan.Deployment> deployments;
    private final AssignmentHolder assignmentHolder;

    /**
     * Working state for one rounding attempt: the assignment indicator and allocation count of every
     * (deployment, node) pair, plus a {@link ResourceTracker} charged with the committed allocations.
     */
    private class AssignmentHolder {
        /** Assignment indicator per (deployment, node) pair; a value of 1 marks a committed pair. */
        private final Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> assignments = new HashMap<>();
        /** Allocation count per (deployment, node) pair; may be fractional until rounded. */
        private final Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> allocations = new HashMap<>();
        private final ResourceTracker resourceTracker;

        private AssignmentHolder() {
            this.resourceTracker = new ResourceTracker(nodes, deployments);
        }

        /** Creates an independent copy of {@code other}, so a randomized round cannot disturb the original. */
        private AssignmentHolder(AssignmentHolder other) {
            this.assignments.putAll(other.assignments);
            this.allocations.putAll(other.allocations);
            this.resourceTracker = new ResourceTracker(other.resourceTracker);
        }

        /**
         * Loads a fractional solution into this holder. Pairs whose assignment is exactly 1 and whose allocation
         * count is (within {@code EPS}) an integer are charged to the resource tracker immediately.
         *
         * @param allocationVars fractional allocation count for every (deployment, node) pair
         * @param assignmentVars fractional assignment indicator for every (deployment, node) pair
         */
        private void initializeAssignments(
            Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> allocationVars,
            Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> assignmentVars
        ) {
            for (AssignmentPlan.Node node : nodes) {
                for (AssignmentPlan.Deployment deployment : deployments) {
                    Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key = Tuple.tuple(deployment, node);
                    double assignment = assignmentVars.get(key);
                    double allocationCount = allocationVars.get(key);
                    if (assignment == 1.0 && isInteger(allocationCount)) {
                        resourceTracker.assign(deployment, node, (int) Math.rint(allocationCount));
                    }
                    assignments.put(key, assignment);
                    allocations.put(key, allocationCount);
                }
            }
        }

        /** Runs {@link #assignUnderSubscribedNodes(Collection)} over every node. */
        private void assignUnderSubscribedNodes() {
            assignUnderSubscribedNodes(nodes);
        }

        /**
         * Commits the soft assignments on nodes where everything tentatively placed would fit.
         *
         * <p>A node qualifies when the deployments with a positive assignment on it, each rounded up to a whole
         * number of allocations, fit within the node's available memory and cores. On a qualifying node, every
         * fractional assignment is committed and the node's spare cores are then handed out via
         * {@link #assignExcessCores}.
         *
         * @param nodeSelection the nodes to examine, visited in ascending {@link #decreasingQualityNodeOrder} order
         */
        private void assignUnderSubscribedNodes(Collection<AssignmentPlan.Node> nodeSelection) {
            List<AssignmentPlan.Node> orderedNodes = nodeSelection.stream()
                .sorted(Comparator.comparingDouble(this::decreasingQualityNodeOrder))
                .toList();
            for (AssignmentPlan.Node node : orderedNodes) {
                List<AssignmentPlan.Deployment> candidates = new ArrayList<>();
                long requiredMemory = 0;
                int requiredThreads = 0;
                for (AssignmentPlan.Deployment deployment : deployments) {
                    Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key = Tuple.tuple(deployment, node);
                    if (assignments.get(key) > 0) {
                        int roundedUp = (int) Math.ceil(allocations.get(key));
                        requiredMemory += deployment.estimateMemoryUsageBytes(roundedUp);
                        requiredThreads += roundedUp * deployment.threadsPerAllocation();
                        candidates.add(deployment);
                    }
                }
                if (requiredMemory > node.availableMemoryBytes() || requiredThreads > node.cores()) {
                    continue;
                }
                for (AssignmentPlan.Deployment deployment : candidates) {
                    Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key = Tuple.tuple(deployment, node);
                    double assignment = assignments.get(key);
                    if (assignment > 0 && assignment < 1) {
                        assignModelToNode(deployment, node, allocationsToAssign(key));
                    }
                }
                assignExcessCores(node);
            }
        }

        /** Returns the pair's allocation count: rounded if it is already integral (within EPS), otherwise rounded up. */
        private int allocationsToAssign(Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key) {
            double allocationCount = allocations.get(key);
            return isInteger(allocationCount) ? (int) Math.rint(allocationCount) : (int) Math.ceil(allocationCount);
        }

        /**
         * Commits {@code deployment} to {@code node} and charges the resource tracker. The assignment is set to 1.
         * The allocation count is set to {@code requestedAllocations}, capped at the allocations the deployment
         * still needs.
         */
        private void assignModelToNode(AssignmentPlan.Deployment deployment, AssignmentPlan.Node node, int requestedAllocations) {
            Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key = Tuple.tuple(deployment, node);
            int granted = Math.min(requestedAllocations, resourceTracker.remainingModelAllocations.get(deployment).intValue());
            assignments.put(key, 1.0);
            allocations.put(key, (double) granted);
            resourceTracker.assign(deployment, node, granted);
        }

        /**
         * Sort key used when ordering nodes.
         *
         * <p>The key sums {@code allocations * threadsPerAllocation} over every deployment with a positive allocation
         * count on {@code node}. Each term is weighted twice when the deployment already has a current allocation on
         * the node.
         *
         * <p>NOTE(review): callers sort ascending on this value, so the node with the smallest sum is visited first
         * despite the "decreasing" in the name. Confirm the intended direction before changing it.
         */
        private double decreasingQualityNodeOrder(AssignmentPlan.Node node) {
            double quality = 0.0;
            for (AssignmentPlan.Deployment deployment : deployments) {
                double allocated = allocations.get(Tuple.tuple(deployment, node));
                if (allocated > 0) {
                    int weight = deployment.currentAllocationsByNodeId().containsKey(node.id()) ? 2 : 1;
                    quality += weight * allocated * deployment.threadsPerAllocation();
                }
            }
            return quality;
        }

        /**
         * Hands out a node's spare cores to deployments already committed to that node.
         *
         * <p>This only happens when the node has spare cores and no soft assignments. Eligible deployments are those
         * committed to the node that still need allocations, visited in {@link #remainingModelOrder} order. Each
         * receives as many extra allocations as {@code findExcessAllocations} allows, given the remaining cores,
         * allocations and memory. Afterwards, soft assignments of deployments that are now fully satisfied are
         * cleared.
         */
        private void assignExcessCores(AssignmentPlan.Node node) {
            if (resourceTracker.remainingNodeCores.get(node) == 0 || hasSoftAssignments(node)) {
                return;
            }
            List<AssignmentPlan.Deployment> committed = deployments.stream()
                .filter(d -> assignments.get(Tuple.tuple(d, node)) == 1.0 && resourceTracker.remainingModelAllocations.get(d) > 0)
                .sorted(Comparator.comparingDouble(AssignmentHolder::remainingModelOrder))
                .toList();
            for (AssignmentPlan.Deployment deployment : committed) {
                int remainingCores = resourceTracker.remainingNodeCores.get(node);
                if (remainingCores <= 0) {
                    break;
                }
                int stillNeeded = resourceTracker.remainingModelAllocations.get(deployment);
                long remainingMemory = resourceTracker.remainingNodeMemory.get(node);
                int extraAllocations = deployment.findExcessAllocations(
                    Math.min(remainingCores / deployment.threadsPerAllocation(), stillNeeded),
                    remainingMemory
                );
                allocations.compute(Tuple.tuple(deployment, node), (key, current) -> current + extraAllocations);
                resourceTracker.assign(deployment, node, extraAllocations);
            }
            zeroSoftAssignmentsOfSatisfiedModels();
        }

        /**
         * Sort key that puts deployments with a larger {@code minimumMemoryRequiredBytes} first. Deployments that
         * already have current allocations get double weight, so they sort ahead of equal-sized ones that do not.
         */
        private static double remainingModelOrder(AssignmentPlan.Deployment deployment) {
            int weight = deployment.currentAllocationsByNodeId().isEmpty() ? 1 : 2;
            return weight * (-deployment.minimumMemoryRequiredBytes());
        }

        /** Returns true if any deployment has a soft assignment on {@code node}. */
        private boolean hasSoftAssignments(AssignmentPlan.Node node) {
            return deployments.stream().anyMatch(deployment -> isSoftAssignment(deployment, node));
        }

        /**
         * Returns true if the pair's assignment is strictly between 0 and 1, or its allocation count is not an
         * integer (within EPS).
         */
        private boolean isSoftAssignment(AssignmentPlan.Deployment deployment, AssignmentPlan.Node node) {
            Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key = Tuple.tuple(deployment, node);
            double assignment = assignments.get(key);
            return (assignment > 0 && assignment < 1) || isInteger(allocations.get(key)) == false;
        }

        /** Clears every soft assignment of deployments that need no further allocations. */
        private void zeroSoftAssignmentsOfSatisfiedModels() {
            for (AssignmentPlan.Deployment deployment : deployments) {
                if (resourceTracker.remainingModelAllocations.get(deployment) > 0) {
                    continue;
                }
                for (AssignmentPlan.Node node : nodes) {
                    if (isSoftAssignment(deployment, node)) {
                        unassign(Tuple.tuple(deployment, node));
                    }
                }
            }
        }

        /** Zeroes the pair's assignment and allocation count; the resource tracker is not credited back. */
        private void unassign(Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key) {
            assignments.put(key, 0.0);
            allocations.put(key, 0.0);
        }

        /**
         * Collects every soft (deployment, node) pair. Pairs whose assignment is closest to 0 or 1 come first; ties
         * are broken by the most threads ({@code allocations * threadsPerAllocation}) first.
         */
        private List<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>> createSoftAssignmentQueue() {
            List<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>> queue = new ArrayList<>();
            for (AssignmentPlan.Deployment deployment : deployments) {
                for (AssignmentPlan.Node node : nodes) {
                    if (isSoftAssignment(deployment, node)) {
                        queue.add(Tuple.tuple(deployment, node));
                    }
                }
            }
            queue.sort(
                Comparator.comparingDouble(this::assignmentDistanceFromZeroOrOneOrder)
                    .thenComparingDouble(this::assignmentMostRemainingThreadsOrder)
            );
            return queue;
        }

        /** Distance of the pair's assignment from the nearer of 0 and 1. */
        private double assignmentDistanceFromZeroOrOneOrder(Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key) {
            double assignment = assignments.get(key);
            return Math.min(assignment, 1.0 - assignment);
        }

        /** Negated thread demand of the pair, so that an ascending sort puts the largest demand first. */
        private double assignmentMostRemainingThreadsOrder(Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key) {
            return (-allocations.get(key)) * key.v1().threadsPerAllocation();
        }

        /**
         * Randomly rounds every pair in {@code softAssignmentQueue} that is still soft when it is reached.
         *
         * <p>The allocation count is rounded up with probability equal to its fractional part, otherwise down. The
         * pair is dropped (and the node re-examined by {@link #assignUnderSubscribedNodes(Collection)}) in three
         * cases: the rounded count does not fit the node's remaining memory, the node lacks the cores for even one
         * allocation, or the count rounds to zero. It is also dropped when a second random draw exceeds the pair's
         * assignment value. Otherwise the pair is committed with a count capped by cores and by
         * {@code findOptimalAllocations}.
         */
        private void doRandomizedRounding(List<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>> softAssignmentQueue) {
            for (Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key : softAssignmentQueue) {
                AssignmentPlan.Deployment deployment = key.v1();
                AssignmentPlan.Node node = key.v2();
                // Earlier iterations may already have committed or dropped this pair.
                if (isSoftAssignment(deployment, node) == false) {
                    continue;
                }
                double fractionalAllocations = allocations.get(key);
                double roundUpProbability = fractionalAllocations - Math.floor(fractionalAllocations);
                int roundedAllocations = random.nextDouble() < roundUpProbability
                    ? (int) Math.ceil(fractionalAllocations)
                    : (int) Math.floor(fractionalAllocations);

                // The order of these checks matters: the second random draw only happens when the earlier checks pass.
                boolean rejected = deployment.estimateMemoryUsageBytes(roundedAllocations) > resourceTracker.remainingNodeMemory.get(node)
                    || deployment.threadsPerAllocation() > resourceTracker.remainingNodeCores.get(node)
                    || roundedAllocations == 0
                    || random.nextDouble() > assignments.get(key);

                if (rejected) {
                    unassign(key);
                    assignUnderSubscribedNodes(Set.of(node));
                } else {
                    int remainingCores = resourceTracker.remainingNodeCores.get(node);
                    long remainingMemory = resourceTracker.remainingNodeMemory.get(node);
                    int allocationsToPlace = deployment.findOptimalAllocations(
                        Math.min(roundedAllocations, remainingCores / deployment.threadsPerAllocation()),
                        remainingMemory
                    );
                    assignModelToNode(deployment, node, allocationsToPlace);
                    unassignOversizedModels(node);
                    assignExcessCores(node);
                }
            }
        }

        /** Drops every uncommitted pair on {@code node} whose minimum memory no longer fits the node's remaining memory. */
        private void unassignOversizedModels(AssignmentPlan.Node node) {
            for (AssignmentPlan.Deployment deployment : deployments) {
                Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key = Tuple.tuple(deployment, node);
                if (assignments.get(key) < 1.0 && deployment.minimumMemoryRequiredBytes() > resourceTracker.remainingNodeMemory.get(node)) {
                    unassign(key);
                }
            }
        }

        /** Builds a plan from {@link #tryAssigningRemainingCores()}, skipping pairs the builder reports it cannot assign. */
        private AssignmentPlan toPlan() {
            AssignmentPlan.Builder builder = AssignmentPlan.builder(nodes, deployments);
            for (Map.Entry<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Integer> entry : tryAssigningRemainingCores().entrySet()) {
                AssignmentPlan.Deployment deployment = entry.getKey().v1();
                AssignmentPlan.Node node = entry.getKey().v2();
                int allocationCount = entry.getValue();
                if (builder.canAssign(deployment, node, allocationCount)) {
                    builder.assignModelToNode(deployment, node, allocationCount);
                }
            }
            return builder.build();
        }

        /**
         * Produces integral allocation counts for every (deployment, node) pair.
         *
         * <p>First, every allocation count is floored and charged to a fresh resource tracker. Then each deployment
         * that still needs allocations (visited in {@link #remainingModelOrder} order) is placed greedily. Candidate
         * nodes are those without an allocation of that deployment that still have room for its minimum memory and
         * one allocation's threads; they are tried in {@link #remainingNodeOrder} order. Placement stops once the
         * deployment is satisfied.
         *
         * @return allocation count per (deployment, node) pair; pairs never placed map to 0
         */
        private Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Integer> tryAssigningRemainingCores() {
            Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Integer> result = new HashMap<>();
            ResourceTracker tracker = new ResourceTracker(nodes, deployments);
            for (AssignmentPlan.Deployment deployment : deployments) {
                for (AssignmentPlan.Node node : nodes) {
                    Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> key = Tuple.tuple(deployment, node);
                    int floored = (int) Math.floor(allocations.getOrDefault(key, 0.0));
                    result.put(key, floored);
                    if (floored > 0) {
                        tracker.assign(deployment, node, floored);
                    }
                }
            }

            List<AssignmentPlan.Deployment> unsatisfied = deployments.stream()
                .filter(d -> tracker.remainingModelAllocations.get(d) > 0)
                .sorted(Comparator.comparingDouble(AssignmentHolder::remainingModelOrder))
                .toList();
            for (AssignmentPlan.Deployment deployment : unsatisfied) {
                List<AssignmentPlan.Node> candidateNodes = nodes.stream()
                    .filter(
                        n -> tracker.remainingNodeMemory.get(n) >= deployment.minimumMemoryRequiredBytes()
                            && tracker.remainingNodeCores.get(n) >= deployment.threadsPerAllocation()
                            && result.get(Tuple.tuple(deployment, n)) == 0
                    )
                    .sorted(
                        Comparator.comparingDouble(
                            n -> remainingNodeOrder(
                                n,
                                deployment,
                                tracker.remainingNodeCores.get(n),
                                tracker.remainingNodeMemory.get(n),
                                tracker.remainingModelAllocations.get(deployment)
                            )
                        )
                    )
                    .toList();
                for (AssignmentPlan.Node node : candidateNodes) {
                    int coreLimit = tracker.remainingNodeCores.get(node) / deployment.threadsPerAllocation();
                    int stillNeeded = tracker.remainingModelAllocations.get(deployment);
                    // The second argument of findOptimalAllocations is the node's available memory (as in
                    // doRandomizedRounding), not an allocation count.
                    long remainingMemory = tracker.remainingNodeMemory.get(node);
                    int memoryLimit = deployment.findOptimalAllocations(coreLimit, remainingMemory);
                    int assigning = Math.min(coreLimit, Math.min(stillNeeded, memoryLimit));
                    tracker.assign(deployment, node, assigning);
                    result.put(Tuple.tuple(deployment, node), assigning);
                    if (tracker.remainingModelAllocations.get(deployment) == 0) {
                        break;
                    }
                }
            }
            return result;
        }

        /**
         * Sort key for candidate nodes in {@link #tryAssigningRemainingCores()}; lower is preferred. The key adds:
         * <ul>
         *   <li>1 if the deployment has no current allocation on the node;</li>
         *   <li>0.5 if the node has more cores than the deployment's remaining thread demand;</li>
         *   <li>0.01 times the gap between the node's remaining cores and that demand;</li>
         *   <li>0.01 times the node's remaining memory.</li>
         * </ul>
         *
         * <p>NOTE(review): remaining memory originates from {@code availableMemoryBytes()}, so the last term is in
         * bytes and likely dominates the others. Confirm the intended weighting.
         */
        public static double remainingNodeOrder(
            AssignmentPlan.Node node,
            AssignmentPlan.Deployment deployment,
            int remainingNodeCores,
            long remainingNodeMemory,
            int remainingModelAllocations
        ) {
            int threadDemand = remainingModelAllocations * deployment.threadsPerAllocation();
            double notCurrentlyAllocated = deployment.currentAllocationsByNodeId().containsKey(node.id()) ? 0 : 1;
            double coreSurplus = remainingNodeCores <= threadDemand ? 0.0 : 0.5;
            return notCurrentlyAllocated + coreSurplus + (0.01 * distance(remainingNodeCores, threadDemand)) + (0.01 * remainingNodeMemory);
        }
    }

    /**
     * Tracks the memory and cores left on each node, and the allocations still needed by each deployment.
     *
     * <p>Memory is charged only the first time a (deployment, node) pair is assigned. Pairs for which the
     * deployment already has a current allocation on the node are pre-registered, so they are never charged memory.
     * NOTE(review): presumably {@code availableMemoryBytes()} already accounts for them; confirm.
     */
    public static class ResourceTracker {
        /** (deployment, node) pairs whose memory has already been accounted for. */
        final Set<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>> assignments = new HashSet<>();
        final Map<AssignmentPlan.Node, Long> remainingNodeMemory;
        final Map<AssignmentPlan.Node, Integer> remainingNodeCores;
        final Map<AssignmentPlan.Deployment, Integer> remainingModelAllocations;

        ResourceTracker(Collection<AssignmentPlan.Node> nodes, Collection<AssignmentPlan.Deployment> deployments) {
            this.remainingNodeMemory = Maps.newHashMapWithExpectedSize(nodes.size());
            this.remainingNodeCores = Maps.newHashMapWithExpectedSize(nodes.size());
            this.remainingModelAllocations = Maps.newHashMapWithExpectedSize(deployments.size());
            for (AssignmentPlan.Node node : nodes) {
                remainingNodeMemory.put(node, node.availableMemoryBytes());
                remainingNodeCores.put(node, node.cores());
            }
            for (AssignmentPlan.Deployment deployment : deployments) {
                for (AssignmentPlan.Node node : nodes) {
                    if (deployment.currentAllocationsByNodeId().containsKey(node.id())) {
                        assignments.add(Tuple.tuple(deployment, node));
                    }
                }
                remainingModelAllocations.put(deployment, deployment.allocations());
            }
        }

        /** Creates an independent copy of {@code other}. */
        ResourceTracker(ResourceTracker other) {
            this.assignments.addAll(other.assignments);
            this.remainingNodeMemory = new HashMap<>(other.remainingNodeMemory);
            this.remainingNodeCores = new HashMap<>(other.remainingNodeCores);
            this.remainingModelAllocations = new HashMap<>(other.remainingModelAllocations);
        }

        /**
         * Charges {@code allocations} allocations of {@code deployment} on {@code node}. The node loses
         * {@code allocations * threadsPerAllocation} cores, and the deployment needs {@code allocations} fewer
         * allocations. Memory ({@code estimateMemoryUsageBytes(allocations)}) is charged only if the pair was not
         * already registered.
         */
        void assign(AssignmentPlan.Deployment deployment, AssignmentPlan.Node node, int allocations) {
            if (assignments.add(Tuple.tuple(deployment, node))) {
                remainingNodeMemory.compute(node, (k, memory) -> memory - deployment.estimateMemoryUsageBytes(allocations));
            }
            remainingNodeCores.compute(node, (k, cores) -> cores - (allocations * deployment.threadsPerAllocation()));
            remainingModelAllocations.compute(deployment, (k, remaining) -> remaining - allocations);
        }
    }

    /**
     * @param random      source of randomness for the rounding trials
     * @param rounds      number of randomized rounding trials; must be positive
     * @param nodes       nodes available for assignment
     * @param deployments deployments to assign
     * @throws IllegalArgumentException if {@code rounds <= 0}
     * @throws NullPointerException     if {@code random}, {@code nodes} or {@code deployments} is null
     */
    public RandomizedAssignmentRounding(
        Random random,
        int rounds,
        Collection<AssignmentPlan.Node> nodes,
        Collection<AssignmentPlan.Deployment> deployments
    ) {
        if (rounds <= 0) {
            throw new IllegalArgumentException("rounds must be > 0");
        }
        this.random = Objects.requireNonNull(random);
        this.rounds = rounds;
        this.nodes = Objects.requireNonNull(nodes);
        this.deployments = Objects.requireNonNull(deployments);
        // Must come last: the holder reads nodes and deployments in its constructor.
        this.assignmentHolder = new AssignmentHolder();
    }

    /**
     * Rounds the fractional solution into an integral plan.
     *
     * <p>The baseline is the plan built from the holder's state before the inputs are loaded. After loading, pairs
     * on nodes that fit everything are committed deterministically. If no soft pairs remain, the resulting plan
     * competes with the baseline. Otherwise {@code rounds} independent randomized roundings are tried, each on a
     * copy of the holder. The plan that compares greatest wins.
     *
     * <p>NOTE(review): the holder is shared across calls, so a second call builds on the first call's state.
     * Confirm that this method is meant to be called once per instance.
     *
     * @param allocationVars fractional allocation count for every (deployment, node) pair
     * @param assignmentVars fractional assignment indicator for every (deployment, node) pair
     * @return the best plan found
     */
    public AssignmentPlan computePlan(
        Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> allocationVars,
        Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> assignmentVars
    ) {
        AssignmentPlan bestPlan = assignmentHolder.toPlan();
        assignmentHolder.initializeAssignments(allocationVars, assignmentVars);
        assignmentHolder.assignUnderSubscribedNodes();
        List<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>> softAssignmentQueue = assignmentHolder.createSoftAssignmentQueue();

        if (softAssignmentQueue.isEmpty()) {
            AssignmentPlan plan = assignmentHolder.toPlan();
            if (plan.compareTo(bestPlan) > 0) {
                bestPlan = plan;
            }
            return bestPlan;
        }

        logger.debug(() -> "Random assignment rounding across [" + rounds + "] rounds");
        for (int round = 0; round < rounds; round++) {
            AssignmentHolder trial = new AssignmentHolder(assignmentHolder);
            trial.doRandomizedRounding(softAssignmentQueue);
            AssignmentPlan trialPlan = trial.toPlan();
            if (trialPlan.compareTo(bestPlan) > 0) {
                bestPlan = trialPlan;
            }
        }
        return bestPlan;
    }

    /** Returns {@code |a - b|}, saturating to {@code Integer.MAX_VALUE} when the difference is {@code Integer.MIN_VALUE}. */
    @SuppressForbidden(reason = "Math#abs(int) is safe here as we protect against MIN_VALUE")
    private static int distance(int a, int b) {
        int difference = a - b;
        if (difference == Integer.MIN_VALUE) {
            return Integer.MAX_VALUE;
        }
        return Math.abs(difference);
    }

    /** Returns true if {@code value} is finite and within {@code EPS} of the nearest integer. */
    private static boolean isInteger(double value) {
        return Double.isFinite(value) && Math.abs(value - Math.rint(value)) < EPS;
    }
}
