package org.elasticsearch.xpack.ml.inference.deployment;

import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.IdsQueryBuilder;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

/**
 * Starts, monitors and stops the native PyTorch inference processes that back trained model
 * deployments on this node.
 * <p>
 * Each deployment task owns one {@link ProcessContext}, keyed by the task id. Inference and control
 * requests are queued on that context's {@link PriorityProcessWorkerExecutorService}. When a process
 * crashes it is restarted while its start count is at most {@link #NUM_RESTART_ATTEMPTS}; after that
 * the task is marked as failed.
 */
public class DeploymentManager {

    private static final Logger logger = LogManager.getLogger(DeploymentManager.class);

    // Shared by every deployment on this node so each request sent to a native process gets a distinct id.
    private static final AtomicLong requestIdCounter = new AtomicLong(1);

    /**
     * Highest start count at which a crashed process is still restarted. A deployment's start count
     * begins at 1 and is reset to 1 when the crashed process had been running for more than 24 hours.
     */
    public static final int NUM_RESTART_ATTEMPTS = 3;

    private final Client client;
    private final NamedXContentRegistry xContentRegistry;
    private final PyTorchProcessFactory pyTorchProcessFactory;
    // ML utility pool: model-load callbacks and startAndLoad run here.
    private final ExecutorService executorServiceForDeployment;
    // Native inference comms pool: result processing and the priority worker loop run here;
    // it is also handed to the process factory and the state streamer.
    private final ExecutorService executorServiceForProcess;
    private final ThreadPool threadPool;
    private final InferenceAuditor inferenceAuditor;
    // One context per running deployment, keyed by TrainedModelDeploymentTask#getId().
    private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
    private final int maxProcesses;

    /**
     * Per-deployment state: the native process, its result processor, state streamer and priority
     * request worker, plus the counters reported by {@link DeploymentManager#getStats}.
     */
    public class ProcessContext {

        private static final String PROCESS_NAME = "inference process";
        private static final TimeValue COMPLETION_TIMEOUT = TimeValue.timeValueMinutes(3);

        private final TrainedModelDeploymentTask task;
        private final PyTorchResultProcessor resultProcessor;
        private final PyTorchStateStreamer stateStreamer;
        private final PriorityProcessWorkerExecutorService priorityProcessWorker;
        private final SetOnce<PyTorchProcess> process = new SetOnce<>();
        private final SetOnce<NlpTask.Processor> nlpTaskProcessor = new SetOnce<>();
        private final SetOnce<TrainedModelInput> modelInput = new SetOnce<>();
        private final SetOnce<TrainedModelPrefixStrings> prefixes = new SetOnce<>();
        private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
        private final AtomicInteger timeoutCount = new AtomicInteger();
        // How many times this deployment has been started; drives the restart-on-crash limit.
        private final AtomicInteger startsCount = new AtomicInteger();

        // Set once the native process has been created; null until then.
        private volatile Instant startTime;
        // Both updated from thread settings reported back by the result processor.
        private volatile Integer numThreadsPerAllocation;
        private volatile Integer numAllocations;
        private volatile boolean isStopped;

        /**
         * @param task               the deployment this context serves
         * @param initialStartsCount the start count for this start, or {@code null} for a first start (treated as 1)
         */
        ProcessContext(TrainedModelDeploymentTask task, Integer initialStartsCount) {
            this.task = Objects.requireNonNull(task);
            this.resultProcessor = new PyTorchResultProcessor(task.getDeploymentId(), threadSettings -> {
                this.numThreadsPerAllocation = threadSettings.numThreadsPerAllocation();
                this.numAllocations = threadSettings.numAllocations();
            });
            this.stateStreamer = new PyTorchStateStreamer(client, executorServiceForProcess, xContentRegistry);
            this.priorityProcessWorker = new PriorityProcessWorkerExecutorService(
                threadPool.getThreadContext(),
                PROCESS_NAME,
                task.getParams().getQueueCapacity()
            );
            this.startsCount.set(initialStartsCount == null ? 1 : initialStartsCount);
        }

        public PyTorchResultProcessor getResultProcessor() {
            return resultProcessor;
        }

        /**
         * Creates the native process, then loads the model into it. On success the priority worker
         * is started and {@code true} is passed to the listener. Must run on the ML utility pool.
         */
        synchronized void startAndLoad(TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
            assert Thread.currentThread().getName().contains(MachineLearning.UTILITY_THREAD_POOL_NAME)
                : Strings.format(
                    "Must execute from [%s] but thread is [%s]",
                    MachineLearning.UTILITY_THREAD_POOL_NAME,
                    Thread.currentThread().getName()
                );

            if (isStopped) {
                logger.debug("[{}] model stopped before it is started", task.getDeploymentId());
                loadedListener.onFailure(new IllegalArgumentException("model stopped before it is started"));
                return;
            }

            logger.debug("[{}] start and load", task.getDeploymentId());
            process.set(
                pyTorchProcessFactory.createProcess(
                    task,
                    executorServiceForProcess,
                    () -> resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES),
                    onProcessCrashHandleRestarts(startsCount, task.getDeploymentId())
                )
            );
            startTime = Instant.now();
            logger.debug("[{}] process started", task.getDeploymentId());

            try {
                loadModel(modelLocation, loadedListener.delegateFailureAndWrap((listener, loaded) -> {
                    // The deployment may have been stopped while the model was being loaded.
                    if (isStopped) {
                        logger.debug("[{}] model loaded but process is stopped", task.getDeploymentId());
                        killProcessIfPresent();
                        listener.onFailure(new IllegalStateException("model loaded but process is stopped"));
                        return;
                    }
                    logger.debug("[{}] model loaded, starting priority process worker thread", task.getDeploymentId());
                    startPriorityProcessWorker();
                    listener.onResponse(loaded);
                }));
            } catch (Exception e) {
                loadedListener.onFailure(e);
            }
        }

        /**
         * Builds the callback invoked when the native process crashes. It tears down this context and
         * restarts the deployment with an incremented start count, or marks the task failed once the
         * count exceeds {@link #NUM_RESTART_ATTEMPTS}.
         */
        private Consumer<String> onProcessCrashHandleRestarts(AtomicInteger startsCount, String deploymentId) {
            return reason -> {
                if (isThisProcessOlderThan1Day()) {
                    startsCount.set(1);
                    logger.error(
                        "["
                            + task.getDeploymentId()
                            + "] inference process crashed due to reason ["
                            + reason
                            + "]. This process was started more than 24 hours ago; the starts count is reset to 1."
                    );
                } else {
                    logger.error("[{}] inference process crashed due to reason [{}]", task.getDeploymentId(), reason);
                }

                processContextByAllocation.remove(task.getId());
                isStopped = true;
                resultProcessor.stop();
                stateStreamer.cancel();

                if (startsCount.get() > NUM_RESTART_ATTEMPTS) {
                    finishClosingProcess(startsCount, reason, deploymentId);
                    return;
                }

                String logAndAuditMessage = "Inference process ["
                    + task.getDeploymentId()
                    + "] failed due to ["
                    + reason
                    + "]. This is the ["
                    + startsCount.get()
                    + "] failure in 24 hours, and the process will be restarted.";
                logger.info(logAndAuditMessage);
                threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
                    .execute(() -> inferenceAuditor.warning(deploymentId, logAndAuditMessage));
                priorityProcessWorker.shutdownNow();

                ActionListener<TrainedModelDeploymentTask> restartListener = ActionListener.wrap(
                    restartedTask -> logger.debug("Completed restart of inference process, the [{}] start", startsCount),
                    e -> finishClosingProcess(
                        startsCount,
                        "Failed to restart inference process because of error [" + e.getMessage() + "]",
                        deploymentId
                    )
                );
                startDeployment(task, startsCount.incrementAndGet(), restartListener);
            };
        }

        /** True when the process was started more than 24 hours ago; false if it has no start time yet. */
        private boolean isThisProcessOlderThan1Day() {
            Instant started = startTime;
            // A crash may be reported before startTime is assigned in startAndLoad.
            return started != null && started.isBefore(Instant.now().minus(Duration.ofDays(1)));
        }

        /** Gives up on the deployment: logs and audits, fails queued work and marks the task failed. */
        private void finishClosingProcess(AtomicInteger startsCount, String reason, String deploymentId) {
            String logAndAuditMessage = "["
                + task.getDeploymentId()
                + "] inference process failed after ["
                + startsCount.get()
                + "] starts in 24 hours, not restarting again.";
            logger.warn(logAndAuditMessage);
            threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
                .execute(() -> inferenceAuditor.error(deploymentId, logAndAuditMessage));
            priorityProcessWorker.shutdownNowWithError(new IllegalStateException(reason));
            closeNlpTaskProcessor();
            task.setFailed("inference process crashed due to reason [" + reason + "]");
        }

        /** Runs the priority worker loop on the native inference comms pool. */
        void startPriorityProcessWorker() {
            executorServiceForProcess.submit(priorityProcessWorker::start);
        }

        /** Stops immediately: shuts the worker down without draining and kills the native process. */
        synchronized void forcefullyStopProcess() {
            logger.debug(() -> Strings.format("[%s] Forcefully stopping process", task.getDeploymentId()));
            prepareInternalStateForShutdown();

            if (priorityProcessWorker.isShutdown()) {
                handleAlreadyShuttingDownWorker();
            } else {
                priorityProcessWorker.shutdown();
            }

            killProcessIfPresent();
            closeNlpTaskProcessor();
        }

        private void prepareInternalStateForShutdown() {
            isStopped = true;
            resultProcessor.stop();
            stateStreamer.cancel();
        }

        private void handleAlreadyShuttingDownWorker() {
            logger.debug(() -> Strings.format("[%s] Process worker was already marked for shutdown", task.getDeploymentId()));
            priorityProcessWorker.notifyQueueRunnables();
        }

        private void killProcessIfPresent() {
            PyTorchProcess pyTorchProcess = process.get();
            if (pyTorchProcess == null) {
                return;
            }
            try {
                pyTorchProcess.kill(true);
            } catch (IOException e) {
                logger.error(() -> "[" + task.getDeploymentId() + "] Failed to kill process", e);
            }
        }

        private void closeNlpTaskProcessor() {
            NlpTask.Processor processor = nlpTaskProcessor.get();
            if (processor != null) {
                processor.close();
            }
        }

        /** Stops gracefully: lets queued work finish (bounded by COMPLETION_TIMEOUT), then closes the process. */
        private synchronized void stopProcessAfterCompletingPendingWork() {
            logger.debug(() -> Strings.format("[%s] Stopping process after completing its pending work", task.getDeploymentId()));
            prepareInternalStateForShutdown();

            if (priorityProcessWorker.isShutdown()) {
                handleAlreadyShuttingDownWorker();
            } else {
                signalAndWaitForWorkerTermination();
            }

            stopProcessGracefully();
            closeNlpTaskProcessor();
        }

        private void signalAndWaitForWorkerTermination() {
            try {
                awaitTerminationAfterCompletingWork();
            } catch (TimeoutException e) {
                logger.warn(Strings.format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId()), e);
                // The worker did not drain in time; fail whatever is still queued.
                priorityProcessWorker.shutdown();
                priorityProcessWorker.notifyQueueRunnables();
            }
        }

        private void awaitTerminationAfterCompletingWork() throws TimeoutException {
            try {
                priorityProcessWorker.shutdown();
                if (!priorityProcessWorker.awaitTermination(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES)) {
                    throw new TimeoutException(
                        Strings.format("Timed out waiting for process worker to complete for process %s", PROCESS_NAME)
                    );
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                logger.info(Strings.format("[%s] Interrupted waiting for process worker to complete", task.getDeploymentId()));
            }
        }

        private void stopProcessGracefully() {
            try {
                closeProcessIfPresent();
                resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES);
            } catch (TimeoutException e) {
                logger.warn(Strings.format("[%s] Timed out waiting for results processor to stop", task.getDeploymentId()), e);
            }
        }

        private void closeProcessIfPresent() {
            PyTorchProcess pyTorchProcess = process.get();
            if (pyTorchProcess == null) {
                return;
            }
            try {
                pyTorchProcess.close();
            } catch (IOException e) {
                logger.error(Strings.format("[%s] Failed to stop process gracefully, attempting to kill it", task.getDeploymentId()), e);
                killProcessIfPresent();
            }
        }

        /**
         * Streams the model from its index location into the native process. Only
         * {@link IndexLocation} is supported. The outcome is delivered on the deployment (utility) executor.
         */
        void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
            if (isStopped) {
                listener.onFailure(new IllegalArgumentException("Process has stopped, model loading canceled"));
                return;
            }
            if (modelLocation instanceof IndexLocation indexLocation) {
                process.get()
                    .loadModel(
                        task.getParams().getModelId(),
                        indexLocation.getIndexName(),
                        stateStreamer,
                        ActionListener.wrap(
                            loaded -> executorServiceForDeployment.submit(() -> listener.onResponse(loaded)),
                            e -> executorServiceForDeployment.submit(() -> listener.onFailure(e))
                        )
                    );
            } else {
                listener.onFailure(
                    new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]")
                );
            }
        }

        public AtomicInteger getTimeoutCount() {
            return timeoutCount;
        }

        PriorityProcessWorkerExecutorService getPriorityProcessWorker() {
            return priorityProcessWorker;
        }

        public AtomicInteger getRejectedExecutionCount() {
            return rejectedExecutionCount;
        }

        public SetOnce<TrainedModelInput> getModelInput() {
            return modelInput;
        }

        public SetOnce<PyTorchProcess> getProcess() {
            return process;
        }

        public SetOnce<NlpTask.Processor> getNlpTaskProcessor() {
            return nlpTaskProcessor;
        }

        public SetOnce<TrainedModelPrefixStrings> getPrefixStrings() {
            return prefixes;
        }
    }

    /**
     * @param maxProcesses maximum number of process contexts this manager will hold at once
     */
    public DeploymentManager(
        Client client,
        NamedXContentRegistry xContentRegistry,
        ThreadPool threadPool,
        PyTorchProcessFactory pyTorchProcessFactory,
        int maxProcesses,
        InferenceAuditor inferenceAuditor
    ) {
        this.client = Objects.requireNonNull(client);
        this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
        this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
        this.threadPool = Objects.requireNonNull(threadPool);
        this.inferenceAuditor = Objects.requireNonNull(inferenceAuditor);
        this.executorServiceForDeployment = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
        this.executorServiceForProcess = threadPool.executor(MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME);
        this.maxProcesses = maxProcesses;
    }

    /** Returns current statistics for the task's process, or empty if no process context exists for it. */
    public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
        return Optional.ofNullable(processContextByAllocation.get(task.getId())).map(processContext -> {
            PyTorchResultProcessor.ResultStats stats = processContext.getResultProcessor().getResultStats();
            PyTorchResultProcessor.RecentStats recentStats = stats.recentStats();
            return new ModelStats(
                processContext.startTime,
                stats.timingStats().getCount(),
                stats.timingStats().getAverage(),
                stats.timingStatsExcludingCacheHits().getAverage(),
                stats.lastUsed(),
                processContext.priorityProcessWorker.queueSize() + stats.numberOfPendingResults(),
                stats.errorCount(),
                stats.cacheHitCount(),
                processContext.rejectedExecutionCount.intValue(),
                processContext.timeoutCount.intValue(),
                processContext.numThreadsPerAllocation,
                processContext.numAllocations,
                stats.peakThroughput(),
                recentStats.requestsProcessed(),
                recentStats.avgInferenceTime(),
                recentStats.cacheHitCount()
            );
        });
    }

    /** Registers the context unless one already exists; returns the existing context or {@code null}. */
    ProcessContext addProcessContext(Long id, ProcessContext processContext) {
        return processContextByAllocation.putIfAbsent(id, processContext);
    }

    /** Starts a deployment for the first time. */
    public void startDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
        startDeployment(task, null, finalListener);
    }

    /**
     * Starts a deployment. The steps are:
     * <ol>
     * <li>fetch the model config;</li>
     * <li>verify node and model architectures;</li>
     * <li>read the vocabulary document;</li>
     * <li>start the native process and load the model.</li>
     * </ol>
     * On any failure the process context is removed and forcefully stopped before {@code finalListener} fails.
     *
     * @param startsCount start count for this start, or {@code null} for a first start
     */
    public void startDeployment(
        TrainedModelDeploymentTask task,
        Integer startsCount,
        ActionListener<TrainedModelDeploymentTask> finalListener
    ) {
        logger.info("[{}] Starting model deployment of model [{}]", task.getDeploymentId(), task.getModelId());

        if (processContextByAllocation.size() >= maxProcesses) {
            finalListener.onFailure(
                ExceptionsHelper.serverError(
                    "[{}] Could not start inference process as the node reached the max number [{}] of processes",
                    task.getDeploymentId(),
                    maxProcesses
                )
            );
            return;
        }

        ProcessContext processContext = new ProcessContext(task, startsCount);
        if (addProcessContext(task.getId(), processContext) != null) {
            finalListener.onFailure(
                ExceptionsHelper.serverError("[{}] Could not create inference process as one already exists", task.getDeploymentId())
            );
            return;
        }

        ActionListener<TrainedModelDeploymentTask> failedDeploymentListener = ActionListener.wrap(finalListener::onResponse, failure -> {
            ProcessContext failedContext = processContextByAllocation.remove(task.getId());
            if (failedContext != null) {
                failedContext.forcefullyStopProcess();
            }
            finalListener.onFailure(failure);
        });

        ActionListener<Boolean> modelLoadedListener = ActionListener.wrap(loaded -> {
            executorServiceForProcess.execute(() -> processContext.getResultProcessor().process(processContext.process.get()));
            finalListener.onResponse(task);
        }, failedDeploymentListener::onFailure);

        ActionListener<TrainedModelConfig> verifiedModelListener = ActionListener.wrap(modelConfig -> {
            processContext.modelInput.set(modelConfig.getInput());
            processContext.prefixes.set(modelConfig.getPrefixStrings());

            if (!(modelConfig.getInferenceConfig() instanceof NlpConfig nlpConfig)) {
                failedDeploymentListener.onFailure(
                    new IllegalArgumentException(
                        Strings.format(
                            "[%s] must be a pytorch model; found inference config of kind [%s]",
                            modelConfig.getModelId(),
                            modelConfig.getInferenceConfig().getWriteableName()
                        )
                    )
                );
                return;
            }

            task.init(nlpConfig);
            SearchRequest vocabRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
            ClientHelper.executeAsyncWithOrigin(
                client,
                ClientHelper.ML_ORIGIN,
                TransportSearchAction.TYPE,
                vocabRequest,
                ActionListener.wrap(vocabResponse -> {
                    if (vocabResponse.getHits().getHits().length == 0) {
                        failedDeploymentListener.onFailure(
                            new ResourceNotFoundException(
                                Messages.getMessage(
                                    "Could not find vocabulary document [{1}] for trained model [{0}]",
                                    modelConfig.getModelId(),
                                    VocabularyConfig.docId(modelConfig.getModelId())
                                )
                            )
                        );
                        return;
                    }
                    Vocabulary vocabulary = parseVocabularyDocLeniently(vocabResponse.getHits().getAt(0));
                    processContext.nlpTaskProcessor.set(new NlpTask(nlpConfig, vocabulary).createProcessor());
                    // startAndLoad asserts it runs on the utility pool, so hop there from the search callback thread.
                    executorServiceForDeployment.execute(new AbstractRunnable() {
                        @Override
                        public void onFailure(Exception e) {
                            failedDeploymentListener.onFailure(e);
                        }

                        @Override
                        protected void doRun() {
                            processContext.startAndLoad(modelConfig.getLocation(), modelLoadedListener);
                        }
                    });
                }, failedDeploymentListener::onFailure)
            );
        }, failedDeploymentListener::onFailure);

        ClientHelper.executeAsyncWithOrigin(
            client,
            ClientHelper.ML_ORIGIN,
            GetTrainedModelsAction.INSTANCE,
            new GetTrainedModelsAction.Request(task.getParams().getModelId()),
            ActionListener.wrap(getModelResponse -> {
                assert getModelResponse.getResources().results().size() == 1;
                TrainedModelConfig modelConfig = getModelResponse.getResources().results().get(0);
                verifyMlNodesAndModelArchitectures(modelConfig, client, threadPool, verifiedModelListener);
            }, failedDeploymentListener::onFailure)
        );
    }

    /** Verifies ML node and model architectures, then passes {@code configToReturn} on to the listener. */
    void verifyMlNodesAndModelArchitectures(
        TrainedModelConfig configToReturn,
        Client client,
        ThreadPool threadPool,
        ActionListener<TrainedModelConfig> successOrFailureListener
    ) {
        ActionListener<TrainedModelConfig> verifyConfigListener = new ActionListener<TrainedModelConfig>() {
            @Override
            public void onResponse(TrainedModelConfig config) {
                assert Objects.equals(config, configToReturn);
                successOrFailureListener.onResponse(configToReturn);
            }

            @Override
            public void onFailure(Exception e) {
                successOrFailureListener.onFailure(e);
            }
        };
        callVerifyMlNodesAndModelArchitectures(configToReturn, verifyConfigListener, client, threadPool);
    }

    // Separate seam around the static utility call; presumably so tests can override it.
    void callVerifyMlNodesAndModelArchitectures(
        TrainedModelConfig configToReturn,
        ActionListener<TrainedModelConfig> failureListener,
        Client client,
        ThreadPool threadPool
    ) {
        MlPlatformArchitecturesUtil.verifyMlNodesAndModelArchitectures(
            failureListener,
            client,
            threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME),
            configToReturn
        );
    }

    /** Builds a single-hit ids query for the model's vocabulary document. */
    private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig, String modelId) {
        return client.prepareSearch(vocabularyConfig.getIndex())
            .setQuery(new IdsQueryBuilder().addIds(VocabularyConfig.docId(modelId)))
            .setSize(1)
            .setTrackTotalHits(false)
            .request();
    }

    /** Parses the vocabulary document's JSON source, logging and rethrowing any I/O failure. */
    Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException {
        try (
            XContentParser parser = XContentHelper.createParserNotCompressed(
                LoggingDeprecationHandler.XCONTENT_PARSER_CONFIG.withRegistry(xContentRegistry),
                hit.getSourceRef(),
                XContentType.JSON
            )
        ) {
            return Vocabulary.PARSER.apply(parser, null);
        } catch (IOException e) {
            logger.error(() -> "failed to parse trained model vocabulary [" + hit.getId() + "]", e);
            throw e;
        }
    }

    /** Removes the task's process context and stops its process immediately. */
    public void stopDeployment(TrainedModelDeploymentTask task) {
        ProcessContext processContext = processContextByAllocation.remove(task.getId());
        if (processContext == null) {
            logger.warn("[{}] No process context to stop", task.getDeploymentId());
            return;
        }
        logger.info("[{}] Stopping deployment, reason [{}]", task.getDeploymentId(), task.stoppedReason().orElse("unknown"));
        processContext.forcefullyStopProcess();
    }

    /** Removes the task's process context and stops its process once queued work has completed. */
    public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task) {
        ProcessContext processContext = processContextByAllocation.remove(task.getId());
        if (processContext == null) {
            logger.warn("[{}] No process context to stop gracefully", task.getDeploymentId());
            return;
        }
        logger.info(
            "[{}] Stopping deployment after completing pending tasks, reason [{}]",
            task.getDeploymentId(),
            task.stoppedReason().orElse("unknown")
        );
        processContext.stopProcessAfterCompletingPendingWork();
    }

    /**
     * Queues an inference request on the task's process.
     *
     * @param skipQueue     when true the request is queued with HIGH rather than NORMAL priority
     * @param chunkResponse forwarded unchanged to {@link InferencePyTorchAction}
     */
    public void infer(
        TrainedModelDeploymentTask task,
        InferenceConfig config,
        NlpInferenceInput input,
        boolean skipQueue,
        TimeValue timeout,
        TrainedModelPrefixStrings.PrefixType prefixType,
        CancellableTask parentActionTask,
        boolean chunkResponse,
        ActionListener<InferenceResults> listener
    ) {
        ProcessContext processContext = getProcessContext(task, listener::onFailure);
        if (processContext == null) {
            return;
        }
        PriorityProcessWorkerExecutorService.RequestPriority priority = skipQueue
            ? PriorityProcessWorkerExecutorService.RequestPriority.HIGH
            : PriorityProcessWorkerExecutorService.RequestPriority.NORMAL;
        executePyTorchAction(
            processContext,
            priority,
            new InferencePyTorchAction(
                task.getDeploymentId(),
                requestIdCounter.getAndIncrement(),
                timeout,
                processContext,
                config,
                input,
                prefixType,
                threadPool,
                parentActionTask,
                chunkResponse,
                listener
            )
        );
    }

    /** Sends a thread-settings control message with HIGHEST priority. */
    public void updateNumAllocations(
        TrainedModelDeploymentTask task,
        int numAllocations,
        TimeValue timeout,
        ActionListener<ThreadSettings> listener
    ) {
        ProcessContext processContext = getProcessContext(task, listener::onFailure);
        if (processContext == null) {
            return;
        }
        executePyTorchAction(
            processContext,
            PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST,
            new ThreadSettingsControlMessagePytorchAction(
                task.getDeploymentId(),
                requestIdCounter.getAndIncrement(),
                numAllocations,
                timeout,
                processContext,
                threadPool,
                listener
            )
        );
    }

    /** Sends a clear-cache control message with HIGHEST priority; acknowledges on success. */
    public void clearCache(TrainedModelDeploymentTask task, TimeValue timeout, ActionListener<AcknowledgedResponse> listener) {
        ProcessContext processContext = getProcessContext(task, listener::onFailure);
        if (processContext == null) {
            return;
        }
        executePyTorchAction(
            processContext,
            PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST,
            new ClearCacheControlMessagePytorchAction(
                task.getDeploymentId(),
                requestIdCounter.getAndIncrement(),
                timeout,
                processContext,
                threadPool,
                listener.delegateFailureAndWrap((l, cleared) -> l.onResponse(AcknowledgedResponse.TRUE))
            )
        );
    }

    /** Enqueues the action; queue rejections are counted, and every failure is routed to the action. */
    void executePyTorchAction(
        ProcessContext processContext,
        PriorityProcessWorkerExecutorService.RequestPriority priority,
        AbstractPyTorchAction<?> action
    ) {
        try {
            processContext.getPriorityProcessWorker().executeWithPriority(action, priority, action.getRequestId());
        } catch (EsRejectedExecutionException e) {
            processContext.getRejectedExecutionCount().incrementAndGet();
            action.onFailure(e);
        } catch (Exception e) {
            action.onFailure(e);
        }
    }

    /** Returns the task's live context, or reports a conflict to {@code errorConsumer} and returns null. */
    private ProcessContext getProcessContext(TrainedModelDeploymentTask task, Consumer<Exception> errorConsumer) {
        if (task.isStopped()) {
            errorConsumer.accept(
                ExceptionsHelper.conflictStatusException(
                    "[{}] is stopping or stopped due to [{}]",
                    task.getDeploymentId(),
                    task.stoppedReason().orElse("")
                )
            );
            return null;
        }

        ProcessContext processContext = processContextByAllocation.get(task.getId());
        if (processContext == null) {
            errorConsumer.accept(ExceptionsHelper.conflictStatusException("[{}] process context missing", task.getDeploymentId()));
            return null;
        }
        return processContext;
    }
}
