package org.elasticsearch.xpack.ml.inference.pytorch.process;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.ml.inference.persistence.ChunkedTrainedModelRestorer;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.class */
/**
 * Streams a stored PyTorch model definition to an {@link OutputStream}.
 * <p>
 * The definition is read as a sequence of {@link TrainedModelDefinitionDoc} chunks via a
 * {@link ChunkedTrainedModelRestorer}. Before the first chunk, a {@value #NUM_BYTES_IN_PRELUDE}
 * byte prelude holding the total definition length (as an unsigned 32 bit integer, big-endian)
 * is written. Streaming can be stopped early with {@link #cancel()}.
 * <p>
 * NOTE(review): {@code modelSize} and {@code modelBytesWritten} are never reset, so an
 * instance appears intended to stream a single model once.
 */
public class PyTorchStateStreamer {

    private static final Logger logger = LogManager.getLogger(PyTorchStateStreamer.class);

    /** Number of bytes used to encode the model size ahead of the definition bytes. */
    static final int NUM_BYTES_IN_PRELUDE = 4;

    /** Largest definition length representable in the unsigned 32 bit prelude. */
    static final long UNSIGNED_INT_MAX = 4294967295L;

    private final OriginSettingClient client;
    private final ExecutorService executorService;
    private final NamedXContentRegistry xContentRegistry;

    // Set by cancel(); checked before each chunk is written.
    private volatile boolean isCancelled;

    // -1 until the first chunk arrives and the size prelude has been written.
    private volatile long modelSize = -1;

    private final AtomicLong modelBytesWritten = new AtomicLong();

    /**
     * @param client                 client used for searches; wrapped so requests carry the {@code "ml"} origin
     * @param executorService        executor handed to the chunk restorer
     * @param namedXContentRegistry  registry used to parse the definition documents
     * @throws NullPointerException if any argument is {@code null}
     */
    public PyTorchStateStreamer(Client client, ExecutorService executorService, NamedXContentRegistry namedXContentRegistry) {
        this.client = new OriginSettingClient(Objects.requireNonNull(client), "ml");
        this.executorService = Objects.requireNonNull(executorService);
        this.xContentRegistry = Objects.requireNonNull(namedXContentRegistry);
    }

    /**
     * Requests that streaming stop. Chunks that arrive after this call are not written.
     */
    public void cancel() {
        this.isCancelled = true;
    }

    /**
     * Restores the definition of {@code modelId} from {@code index} and writes it to {@code restoreStream}.
     *
     * @param modelId       id of the model whose definition is restored
     * @param index         index searched for the definition documents
     * @param restoreStream destination for the size prelude and definition bytes
     * @param listener      receives the restorer's completion flag ({@code false} is treated as
     *                      cancellation), or any failure reported by the restorer
     */
    public void writeStateToStream(String modelId, String index, OutputStream restoreStream, ActionListener<Boolean> listener) {
        ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client, executorService, xContentRegistry);
        restorer.setSearchIndex(index);
        // Presumably limits each search page to a single definition doc; confirm against the restorer.
        restorer.setSearchSize(1);

        CheckedFunction<TrainedModelDefinitionDoc, Boolean, IOException> chunkWriter = doc -> writeChunk(doc, restoreStream);

        Consumer<Boolean> onComplete = success -> {
            logger.debug(
                "model [{}] state restored in [{}] documents from index [{}]",
                modelId,
                restorer.getNumDocsWritten(),
                index
            );
            if (success == false) {
                logger.info("[{}] loading model state cancelled", modelId);
            } else if (modelBytesWritten.get() != modelSize) {
                // A size mismatch is only logged; the listener is still told the restore succeeded.
                logger.error(
                    "model [{}] restored state size [{}] does not equal the expected model size [{}]",
                    modelId,
                    modelBytesWritten.get(),
                    modelSize
                );
            }
            listener.onResponse(success);
        };

        restorer.restoreModelDefinition(chunkWriter, onComplete, listener::onFailure);
    }

    /**
     * Writes one definition chunk, preceded by the size prelude if this is the first chunk.
     *
     * @return {@code false} if streaming was cancelled and nothing was written, {@code true} otherwise
     * @throws IOException           if writing to {@code outputStream} fails
     * @throws IllegalStateException if the first chunk carries an invalid total definition length
     */
    private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStream) throws IOException {
        if (isCancelled) {
            return false;
        }

        if (modelSize == -1) {
            modelSize = writeModelSize(doc.getModelId(), doc.getTotalDefinitionLength(), outputStream);
        }

        // Write only the referenced region of the backing array, which may start at a non-zero offset.
        outputStream.write(doc.getBinaryData().array(), doc.getBinaryData().arrayOffset(), doc.getBinaryData().length());
        modelBytesWritten.addAndGet(doc.getBinaryData().length());
        return true;
    }

    /**
     * Validates the total definition length and writes it as a {@value #NUM_BYTES_IN_PRELUDE} byte prelude.
     *
     * @param modelId               model id, used in error messages
     * @param totalDefinitionLength total length in bytes from the first definition doc; may be {@code null}
     * @param outputStream          destination for the prelude
     * @return the validated length
     * @throws IllegalStateException if the length is {@code null}, not positive, or exceeds {@link #UNSIGNED_INT_MAX}
     * @throws IOException           if writing the prelude fails
     */
    private static long writeModelSize(String modelId, Long totalDefinitionLength, OutputStream outputStream) throws IOException {
        String fieldName = TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName();
        if (totalDefinitionLength == null) {
            throw logAndCreateException(
                String.format(Locale.ROOT, "The definition doc for model [%s] has a null value for field [%s]", modelId, fieldName)
            );
        }
        long size = totalDefinitionLength;
        if (size <= 0) {
            // Zero is rejected as well as negative values, so the message says "non-positive".
            throw logAndCreateException(
                String.format(
                    Locale.ROOT,
                    "The definition doc for model [%s] has a non-positive value [%s] for field [%s]",
                    modelId,
                    totalDefinitionLength,
                    fieldName
                )
            );
        }
        if (size > UNSIGNED_INT_MAX) {
            throw logAndCreateException(
                String.format(
                    Locale.ROOT,
                    "model [%s] has a size [%s] larger than the max size [%s]",
                    modelId,
                    totalDefinitionLength,
                    UNSIGNED_INT_MAX
                )
            );
        }

        // Casting to int keeps the low 32 bits, which is the unsigned encoding for values up to
        // UNSIGNED_INT_MAX. A freshly allocated ByteBuffer is big-endian.
        ByteBuffer prelude = ByteBuffer.allocate(NUM_BYTES_IN_PRELUDE);
        prelude.putInt((int) size);
        outputStream.write(prelude.array());
        return size;
    }

    /** Logs {@code message} at error level and returns an {@link IllegalStateException} carrying it. */
    private static IllegalStateException logAndCreateException(String message) {
        logger.error(message);
        return new IllegalStateException(message);
    }
}
