package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.CharacterUtils;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.UnicodeUtil;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.class */
public final class UnigramTokenizer extends Tokenizer {
    /** Amount subtracted from {@link #minScore} to obtain the score assigned to an unknown token. */
    private static final double K_UNK_PENALTY = 10.0d;
    /** Marker (U+2581) prepended to every piece before it is matched against the vocabulary. */
    static final String PREFIX = "▁";
    /** Lowest score found in the vocabulary; basis for the unknown-token score. */
    private final double minScore;
    /** Score of each vocabulary entry, indexed by token id. */
    private final double[] vocabScores;
    /** Trie of the never-split tokens, used to carve them out of larger pieces. */
    private final CharTrie neverSplit;
    /** Set of the never-split tokens, used for exact-match checks. */
    private final CharArraySet neverSplitHash;
    /** Maps the UTF-8 bytes of a vocabulary entry to its token id. */
    private final Map<BytesRef, Integer> vocabToId;
    /** Byte-level trie over the vocabulary, used to enumerate matching prefixes. */
    private final BytesTrie vocabTrie;
    /** Token id emitted for text that is not covered by the vocabulary. */
    private final int unknownTokenId;
    private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
    private final OffsetAttribute offsetAtt = addAttribute(OffsetAttribute.class);
    // NOTE(review): not read anywhere in this class as shown; tokenize() always merges runs of unknown tokens.
    private final boolean fuseUnk = true;
    /** Scratch buffer holding the UTF-8 form of the sequence being tokenized; replaced by a larger one on demand. */
    private byte[] normalizedByteBuffer = new byte[128];
    /** Tokens already produced for the current whitespace-delimited piece but not yet emitted. */
    private final LinkedList<DelimitedToken.Encoded> tokens = new LinkedList<>();
    /** Every token emitted since the last {@code reset()}, in emission order. */
    private final List<DelimitedToken.Encoded> tokenizedValues = new ArrayList<>();
    private final SimpleWhitespaceTokenizer whitespaceTokenizer = new SimpleWhitespaceTokenizer();

    /**
     * One cell of the best-path lattice filled in by {@code tokenize}: the best-scoring token found so far
     * that ends at a given byte position, together with where that token starts.
     *
     * NOTE(review): the decompiler reports that the access modifier was widened from private.
     */
    public static class BestPathNode {
        /** Vocabulary id of the token ending here; -1 until assigned. */
        private int id;
        /** Accumulated score of the best path ending here. */
        double score;
        /** Byte position at which the token ending here starts; -1 until assigned. */
        private int startsAtBytePos;
        /** Char position at which the token ending here starts; -1 until assigned. */
        private int startsAtCharPos;

        private BestPathNode() {
            this.id = -1;
            this.score = 0.0d;
            this.startsAtBytePos = -1;
            this.startsAtCharPos = -1;
        }
    }

    /**
     * Byte-level trie over vocabulary entries. A node flagged as a leaf marks the end of an inserted entry;
     * leaf nodes may still have children when one entry is a prefix of another.
     *
     * NOTE(review): the decompiler reports that the access modifier was widened from package-private.
     */
    public static class BytesTrie {
        private final Map<Byte, BytesTrie> children = new HashMap<>();
        private boolean isLeaf;

        BytesTrie() {
        }

        private void setLeaf(boolean z) {
            this.isLeaf = z;
        }

        private boolean isLeaf() {
            return this.isLeaf;
        }

        /**
         * Collects every non-empty prefix of {@code bytesRef} that was inserted into this trie.
         *
         * @param bytesRef the bytes to match; only the window {@code [offset, offset + length)} is examined
         * @return the matching prefixes, shortest first; each shares the backing array of {@code bytesRef}
         */
        List<BytesRef> matchingPrefixes(BytesRef bytesRef) {
            final List<BytesRef> prefixes = new ArrayList<>();
            final int end = bytesRef.offset + bytesRef.length;
            int matched = 0;
            BytesTrie node = this;
            for (int pos = bytesRef.offset; pos < end && node != null; pos++) {
                if (matched > 0 && node.isLeaf()) {
                    prefixes.add(new BytesRef(bytesRef.bytes, bytesRef.offset, matched));
                }
                node = node.children.get(bytesRef.bytes[pos]);
                matched++;
            }
            // The loop checks a node before descending, so the node reached last still has to be checked.
            if (node != null && matched > 0 && node.isLeaf()) {
                prefixes.add(new BytesRef(bytesRef.bytes, bytesRef.offset, matched));
            }
            return prefixes;
        }

        /**
         * Adds the bytes of {@code bytesRef} to the trie. Empty input is ignored.
         *
         * @param bytesRef the entry to insert
         */
        void insert(BytesRef bytesRef) {
            if (bytesRef.length == 0) {
                return;
            }
            BytesTrie node = this;
            for (int i = 0; i < bytesRef.length; i++) {
                node = node.children.computeIfAbsent(UnigramTokenizer.fromBytesRef(bytesRef, i), key -> new BytesTrie());
            }
            node.setLeaf(true);
        }

        /**
         * Builds a trie containing every entry of {@code collection}.
         *
         * @param collection the entries to insert
         * @return the populated trie
         */
        public static BytesTrie build(Collection<BytesRef> collection) {
            final BytesTrie root = new BytesTrie();
            for (BytesRef entry : collection) {
                root.insert(entry);
            }
            return root;
        }
    }

    /**
     * Maps one {@code int} to another. {@code tokenize} uses it to translate a char position inside the
     * sequence being tokenized into the offset recorded on the resulting token.
     */
    @FunctionalInterface
    public interface IntToIntFunction {
        /**
         * @param i the input value
         * @return the mapped value
         */
        int apply(int i);
    }

    /**
     * Splits the enclosing tokenizer's {@code input} on whitespace, one token per call to {@link #next()}.
     * The token text is written into the enclosing tokenizer's {@code termAtt}, so each returned token is
     * only valid until {@code termAtt} is next modified.
     */
    class SimpleWhitespaceTokenizer {
        private static final int IO_BUFFER_SIZE = 4096;
        /** Number of chars consumed by buffers that have already been fully processed. */
        private int offset = 0;
        /** Read position inside the current I/O buffer. */
        private int bufferIndex = 0;
        /** Number of valid chars in the current I/O buffer. */
        private int dataLen = 0;
        /** End offset of the last token, or the total input length once the input is exhausted. */
        private int finalOffset = 0;
        private final CharacterUtils.CharacterBuffer ioBuffer = CharacterUtils.newCharacterBuffer(IO_BUFFER_SIZE);

        SimpleWhitespaceTokenizer() {
        }

        /** Discards all buffered state so that reading starts again at offset 0. */
        void reset() {
            this.bufferIndex = 0;
            this.offset = 0;
            this.dataLen = 0;
            this.finalOffset = 0;
            this.ioBuffer.reset();
        }

        /**
         * Reads the next run of non-whitespace code points.
         *
         * @return the next token with its start and end offsets, or {@code null} when the input is exhausted
         * @throws IOException if reading from the underlying input fails
         */
        @Nullable
        DelimitedToken next() throws IOException {
            int length = 0;
            int start = -1;
            int end = -1;
            char[] buffer = UnigramTokenizer.this.termAtt.buffer();
            while (true) {
                if (this.bufferIndex >= this.dataLen) {
                    this.offset += this.dataLen;
                    CharacterUtils.fill(this.ioBuffer, UnigramTokenizer.this.input);
                    if (this.ioBuffer.getLength() == 0) {
                        // Reset so that a later "offset += dataLen" does not count the old buffer twice.
                        this.dataLen = 0;
                        if (length > 0) {
                            // End of input in the middle of a token: emit what has been collected. Reading on
                            // would call codePointAt on an empty window, which throws.
                            break;
                        }
                        this.finalOffset = this.offset;
                        return null;
                    }
                    this.dataLen = this.ioBuffer.getLength();
                    this.bufferIndex = 0;
                }
                final int codePoint = Character.codePointAt(this.ioBuffer.getBuffer(), this.bufferIndex, this.ioBuffer.getLength());
                final int charCount = Character.charCount(codePoint);
                this.bufferIndex += charCount;
                if (!Character.isWhitespace(codePoint)) {
                    if (length == 0) {
                        assert start == -1;
                        start = (this.offset + this.bufferIndex) - charCount;
                        end = start;
                    } else if (length >= buffer.length - 1) {
                        // Make room for a possible surrogate pair.
                        buffer = UnigramTokenizer.this.termAtt.resizeBuffer(2 + length);
                    }
                    end += charCount;
                    length += Character.toChars(codePoint, buffer, length);
                } else if (length > 0) {
                    // Whitespace after at least one token char terminates the token.
                    break;
                }
            }
            UnigramTokenizer.this.termAtt.setLength(length);
            assert start != -1;
            this.finalOffset = end;
            return new DelimitedToken(UnigramTokenizer.this.termAtt, start, end);
        }
    }

    /**
     * Creates a tokenizer from a vocabulary and its per-entry scores.
     *
     * NOTE(review): the decompiler reports that the access modifier was widened from package-private.
     *
     * @param list  tokens that must never be split
     * @param list2 the vocabulary; an entry's position is its token id
     * @param dArr  the score of each vocabulary entry, in the same order as {@code list2}; the array is
     *              handed to the tokenizer without copying
     * @param str   the unknown token, which must be part of the vocabulary
     * @return the configured tokenizer
     * @throws IllegalArgumentException if the vocabulary is empty, {@code str} is {@code null}, the
     *         vocabulary and scores differ in size, or the vocabulary does not contain {@code str}
     */
    public static UnigramTokenizer build(List<String> list, List<String> list2, double[] dArr, String str) {
        if (list2.isEmpty()) {
            throw new IllegalArgumentException("vocab empty");
        }
        if (str == null) {
            throw new IllegalArgumentException("unknown token ID");
        }
        final CharArraySet neverSplitSet = new CharArraySet(list, false);
        final CharTrie neverSplitTrie = CharTrie.build(list);
        if (list2.size() != dArr.length) {
            throw new IllegalArgumentException(
                Strings.format("provided vocabulary [%s] and scores [%s] must have the same size", list2.size(), dArr.length)
            );
        }
        final BytesTrie vocabTrie = new BytesTrie();
        final Map<BytesRef, Integer> tokenToId = Maps.newHashMapWithExpectedSize(list2.size());
        double minScore = Double.POSITIVE_INFINITY;
        int tokenId = 0;
        for (String word : list2) {
            minScore = Double.min(minScore, dArr[tokenId]);
            final BytesRef wordBytes = new BytesRef(word);
            tokenToId.put(wordBytes, tokenId);
            vocabTrie.insert(wordBytes);
            tokenId++;
        }
        final Integer unknownId = tokenToId.get(new BytesRef(str));
        if (unknownId == null) {
            throw new IllegalArgumentException("provided vocabulary does not contain the unknown token of [" + str + "]");
        }
        return new UnigramTokenizer(minScore, dArr, neverSplitTrie, neverSplitSet, tokenToId, vocabTrie, unknownId);
    }

    /**
     * Stores the pre-built lookup structures; no validation or copying is performed here.
     *
     * @param d            the lowest vocabulary score
     * @param dArr         the score of each vocabulary entry, indexed by token id
     * @param charTrie     trie of the never-split tokens
     * @param charArraySet set of the never-split tokens
     * @param map          vocabulary entry bytes to token id
     * @param bytesTrie    byte trie over the vocabulary
     * @param i            the id of the unknown token
     */
    public UnigramTokenizer(double d, double[] dArr, CharTrie charTrie, CharArraySet charArraySet, Map<BytesRef, Integer> map, BytesTrie bytesTrie, int i) {
        // Scores
        this.minScore = d;
        this.vocabScores = dArr;
        // Never-split handling
        this.neverSplit = charTrie;
        this.neverSplitHash = charArraySet;
        // Vocabulary lookup
        this.vocabToId = map;
        this.vocabTrie = bytesTrie;
        this.unknownTokenId = i;
    }

    /**
     * Returns the tokens emitted since the last {@code reset()}.
     *
     * NOTE(review): the decompiler reports that the access modifier was widened from package-private.
     *
     * @return the live internal list (not a copy); it is cleared by {@code reset()} and grows as tokens are emitted
     */
    public List<DelimitedToken.Encoded> getTokenizedValues() {
        return tokenizedValues;
    }

    /**
     * Resets the tokenizer so that a new input can be consumed: drops pending and recorded tokens and
     * rewinds the whitespace splitter.
     *
     * @throws IOException if the superclass reset fails
     */
    @Override
    public void reset() throws IOException {
        super.reset();
        this.tokens.clear();
        this.tokenizedValues.clear();
        this.whitespaceTokenizer.reset();
    }

    /**
     * Sets both the start and end of the offset attribute to the corrected final offset reported by the
     * whitespace splitter.
     *
     * @throws IOException if the superclass call fails
     */
    @Override
    public void end() throws IOException {
        super.end();
        final int finalOffset = correctOffset(this.whitespaceTokenizer.finalOffset);
        this.offsetAtt.setOffset(finalOffset, finalOffset);
    }

    /**
     * Emits the oldest pending token, if any: records it in {@code tokenizedValues} and copies its text
     * and offsets into the term and offset attributes. Does nothing when no token is pending.
     */
    private void popFromTokens() {
        if (!tokens.isEmpty()) {
            final DelimitedToken.Encoded head = tokens.removeFirst();
            tokenizedValues.add(head);
            termAtt.setEmpty().append(head.charSequence());
            offsetAtt.setOffset(head.startOffset(), head.endOffset());
        }
    }

    /**
     * Advances to the next token. Pending tokens from an earlier call are emitted first; otherwise the
     * next whitespace-delimited piece is read, encoded into one or more tokens, and the first is emitted.
     *
     * @return {@code true} if a token was emitted, {@code false} once the input is exhausted
     * @throws IOException if reading the input fails
     */
    @Override
    public boolean incrementToken() throws IOException {
        clearAttributes();
        if (!this.tokens.isEmpty()) {
            popFromTokens();
            return true;
        }
        final DelimitedToken whitespaceToken = this.whitespaceTokenizer.next();
        if (whitespaceToken == null) {
            return false;
        }
        if (this.neverSplitHash.contains(whitespaceToken.charSequence())) {
            // The whole piece is a never-split token: emit it as-is.
            this.tokens.add(encodeNeverSplit(whitespaceToken.charSequence(), whitespaceToken.startOffset(), whitespaceToken.endOffset()));
            popFromTokens();
            return true;
        }
        // Offsets of the sub-pieces are relative to the piece, so shift them by where the piece starts.
        final int baseOffset = whitespaceToken.startOffset();
        for (DelimitedToken piece : TokenizerUtils.splitOutNeverSplit(whitespaceToken.charSequence(), this.neverSplit, this.neverSplitHash)) {
            if (this.neverSplitHash.contains(piece.charSequence())) {
                this.tokens.add(encodeNeverSplit(piece.charSequence(), piece.startOffset() + baseOffset, piece.endOffset() + baseOffset));
            } else {
                this.tokens.addAll(tokenize(MultiCharSequence.from(PREFIX, piece.charSequence()), charPos -> {
                    int position = charPos + baseOffset + piece.startOffset();
                    if (charPos > 0) {
                        // Positions after the start include the prepended PREFIX, which is not in the input.
                        position -= PREFIX.length();
                    }
                    return correctOffset(position);
                }));
            }
        }
        popFromTokens();
        return true;
    }

    /**
     * Encodes a never-split token, falling back to the unknown token id when the text is not in the
     * vocabulary. The offsets are passed through {@code correctOffset}.
     */
    private DelimitedToken.Encoded encodeNeverSplit(CharSequence text, int startOffset, int endOffset) {
        final Integer vocabId = this.vocabToId.get(new BytesRef(text));
        final int id = vocabId == null ? this.unknownTokenId : vocabId;
        return new DelimitedToken.Encoded(text.toString(), id, correctOffset(startOffset), correctOffset(endOffset));
    }

    /**
     * Splits {@code charSequence} into the highest-scoring sequence of vocabulary tokens.
     * <p>
     * The input is converted to UTF-8 and a lattice indexed by byte position is filled in from left to
     * right: each cell keeps the best-scoring token ending at that position. Where no vocabulary entry
     * covers exactly the current code point, an unknown-token candidate is added for that code point so
     * that a path always exists. The best path is then read back from the end, and consecutive unknown
     * tokens are merged into one.
     *
     * @param charSequence     the text to tokenize
     * @param intToIntFunction maps a char position in {@code charSequence} to the offset stored on a token
     * @return the tokens in input order; empty when the input is empty
     */
    List<DelimitedToken.Encoded> tokenize(CharSequence charSequence, IntToIntFunction intToIntFunction) {
        final int charLength = charSequence.length();
        final int requiredBytes = UnicodeUtil.calcUTF16toUTF8Length(charSequence, 0, charLength);
        if (requiredBytes > this.normalizedByteBuffer.length) {
            this.normalizedByteBuffer = new byte[requiredBytes + 1];
        }
        final int numBytes = UnicodeUtil.UTF16toUTF8(charSequence, 0, charLength, this.normalizedByteBuffer);
        final double unknownScore = this.minScore - K_UNK_PENALTY;
        final BestPathNode[] bestPath = new BestPathNode[numBytes + 1];

        // Forward pass: bytePos and charPos advance together, one code point at a time.
        int bytePos = 0;
        int charPos = 0;
        while (charPos < charLength) {
            final double scoreSoFar = bestPath[bytePos] == null ? 0.0d : bestPath[bytePos].score;
            final boolean surrogatePair = charPos + 1 < charLength
                && Character.isSurrogatePair(charSequence.charAt(charPos), charSequence.charAt(charPos + 1));
            final int charWidth = surrogatePair ? 2 : 1;
            final int byteWidth = UnicodeUtil.calcUTF16toUTF8Length(charSequence, charPos, charWidth);
            boolean codePointCovered = false;
            for (BytesRef prefix : this.vocabTrie.matchingPrefixes(new BytesRef(this.normalizedByteBuffer, bytePos, numBytes - bytePos))) {
                final int id = this.vocabToId.get(prefix);
                updateBestPath(bestPath, bytePos + prefix.length, id, this.vocabScores[id] + scoreSoFar, bytePos, charPos);
                if (prefix.length == byteWidth) {
                    codePointCovered = true;
                }
            }
            if (!codePointCovered) {
                updateBestPath(bestPath, bytePos + byteWidth, this.unknownTokenId, unknownScore + scoreSoFar, bytePos, charPos);
            }
            bytePos += byteWidth;
            charPos += charWidth;
        }

        // Backward pass: follow the start positions from the end; tokens are collected in reverse order.
        final List<DelimitedToken.Encoded> pendingUnknowns = new ArrayList<>();
        final List<DelimitedToken.Encoded> result = new ArrayList<>();
        int endBytePos = numBytes;
        int endCharPos = charLength;
        while (endBytePos > 0) {
            final BestPathNode node = bestPath[endBytePos];
            final int startBytePos = node.startsAtBytePos;
            final boolean unknown = node.id == this.unknownTokenId;
            if (!unknown) {
                // A known token ends the current run of unknown tokens.
                flushUnknowns(pendingUnknowns, result);
            }
            final DelimitedToken.Encoded token = new DelimitedToken.Encoded(
                new String(this.normalizedByteBuffer, startBytePos, endBytePos - startBytePos, StandardCharsets.UTF_8),
                node.id,
                intToIntFunction.apply(node.startsAtCharPos),
                intToIntFunction.apply(endCharPos)
            );
            if (unknown) {
                pendingUnknowns.add(token);
            } else {
                result.add(token);
            }
            endBytePos = startBytePos;
            endCharPos = node.startsAtCharPos;
        }
        flushUnknowns(pendingUnknowns, result);
        Collections.reverse(result);
        return result;
    }

    /**
     * Records a candidate token ending at {@code endBytePos} if that cell is still empty or the
     * candidate's score is strictly greater than the score already stored there.
     */
    private static void updateBestPath(BestPathNode[] bestPath, int endBytePos, int id, double score, int startBytePos, int startCharPos) {
        BestPathNode node = bestPath[endBytePos];
        if (node == null) {
            node = new BestPathNode();
            bestPath[endBytePos] = node;
        } else if (!(score > node.score)) {
            return;
        }
        node.id = id;
        node.score = score;
        node.startsAtBytePos = startBytePos;
        node.startsAtCharPos = startCharPos;
    }

    /**
     * Merges a run of unknown tokens, collected in reverse order, into a single token appended to
     * {@code out}, then empties the run. Does nothing when the run is empty.
     */
    private static void flushUnknowns(List<DelimitedToken.Encoded> pendingUnknowns, List<DelimitedToken.Encoded> out) {
        if (pendingUnknowns.isEmpty()) {
            return;
        }
        Collections.reverse(pendingUnknowns);
        out.add(DelimitedToken.Encoded.mergeEncodedTokens(pendingUnknowns));
        pendingUnknowns.clear();
    }

    /**
     * Reads one byte of a {@code BytesRef}.
     *
     * @param bytesRef the reference to read from
     * @param i        index relative to the start of the reference's window (its {@code offset})
     * @return the byte at that position in the backing array
     */
    private static byte fromBytesRef(BytesRef bytesRef, int i) {
        final int absoluteIndex = bytesRef.offset + i;
        return bytesRef.bytes[absoluteIndex];
    }
}
