package org.elasticsearch.xpack.ml.inference.pytorch.process;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.class */
/**
 * Assembles the command line for the {@value #PROCESS_NAME} native process from a
 * deployment's {@link StartTrainedModelDeploymentAction.TaskParams} and asks the
 * {@link NativeController} to start it.
 */
public class PyTorchBuilder {

    /** Name of the native inference process. */
    public static final String PROCESS_NAME = "pytorch_inference";

    private static final String PROCESS_PATH = "./" + PROCESS_NAME;
    private static final String LICENSE_KEY_VALIDATED_ARG = "--validElasticLicenseKeyConfirmed=";
    private static final String NUM_THREADS_PER_ALLOCATION_ARG = "--numThreadsPerAllocation=";
    private static final String NUM_ALLOCATIONS_ARG = "--numAllocations=";
    // NOTE(review): the lowercase "l" in "limit" is kept as-is; the native process parses this
    // exact flag name, so do not "fix" the spelling here without changing the native side too.
    private static final String CACHE_MEMORY_LIMIT_BYTES_ARG = "--cacheMemorylimitBytes=";
    private static final String LOW_PRIORITY_ARG = "--lowPriority";

    private final NativeController nativeController;
    private final ProcessPipes processPipes;
    private final StartTrainedModelDeploymentAction.TaskParams taskParams;

    /**
     * Creates a builder for a single process launch.
     *
     * @param nativeController controller used to start the native process; must not be null
     * @param processPipes pipes that are given the command list before launch; must not be null
     * @param taskParams deployment parameters the command-line flags are derived from; must not be null
     * @throws NullPointerException if any argument is null
     */
    public PyTorchBuilder(
        NativeController nativeController,
        ProcessPipes processPipes,
        StartTrainedModelDeploymentAction.TaskParams taskParams
    ) {
        this.nativeController = Objects.requireNonNull(nativeController);
        this.processPipes = Objects.requireNonNull(processPipes);
        this.taskParams = Objects.requireNonNull(taskParams);
    }

    /**
     * Builds the command, passes it to {@code processPipes.addArgs}, then starts the process
     * through the native controller with that same list.
     *
     * @throws IOException if thrown by {@code addArgs} or {@code startProcess}
     * @throws InterruptedException if thrown by {@code startProcess}
     */
    public void build() throws IOException, InterruptedException {
        List<String> command = buildCommand();
        // The same list instance is passed to both calls; presumably addArgs appends the pipe
        // arguments to it so they become part of the launched command — confirm in ProcessPipes.
        processPipes.addArgs(command);
        nativeController.startProcess(command);
    }

    /**
     * Builds the base command line: the executable path, followed by the flags derived from
     * {@link #taskParams}. The cache flag is only added for a positive cache size, and the
     * low-priority flag only for {@link Priority#LOW}.
     *
     * @return a new mutable list of command-line tokens
     */
    private List<String> buildCommand() {
        // Deliberately a mutable ArrayList: build() hands this list to processPipes.addArgs.
        List<String> command = new ArrayList<>();
        command.add(PROCESS_PATH);
        // NOTE(review): always reports the license as validated; presumably the license check
        // happens before this builder is used — confirm with callers.
        command.add(LICENSE_KEY_VALIDATED_ARG + true);
        command.add(NUM_THREADS_PER_ALLOCATION_ARG + taskParams.getThreadsPerAllocation());
        command.add(NUM_ALLOCATIONS_ARG + taskParams.getNumberOfAllocations());
        if (taskParams.getCacheSizeBytes() > 0) {
            command.add(CACHE_MEMORY_LIMIT_BYTES_ARG + taskParams.getCacheSizeBytes());
        }
        if (taskParams.getPriority() == Priority.LOW) {
            command.add(LOW_PRIORITY_ARG);
        }
        return command;
    }
}
