package org.elasticsearch.xpack.esql.planner;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.elasticsearch.compute.aggregation.IntermediateStateDesc;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.NumericAggregate;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
import org.elasticsearch.xpack.esql.type.EsqlDataTypes;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.AttributeMap;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.FieldAttribute;
import org.elasticsearch.xpack.ql.expression.MetadataAttribute;
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
import org.elasticsearch.xpack.ql.expression.function.Function;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.ql.expression.function.aggregate.SpatialAggregateFunction;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;

/* loaded from: input_file:org/elasticsearch/xpack/esql/planner/AggregateMapper.class */
public class AggregateMapper {
    // Type-name fragments used for NumericAggregate subclasses (and as the base list for CountDistinct);
    // assigned in the static initializer.
    static final List<String> NUMERIC;
    // Type-name fragments used for SpatialAggregateFunction subclasses; assigned in the static initializer.
    static final List<String> SPATIAL;
    // Aggregate function classes the no-arg constructor builds its lookup table from.
    static final List<? extends Class<? extends Function>> AGG_FUNCTIONS;
    // Immutable table from aggregator definition to its intermediate state description, built once in the constructor.
    private final Map<AggDef, List<IntermediateStateDesc>> mapper;
    // Memoized results of map(...), keyed by the un-aliased expression only.
    private final HashMap<Expression, List<? extends NamedExpression>> cache;
    // True when assertions are disabled for this class; consulted by the guarded check in typeAndNames.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.elasticsearch.xpack.esql.planner.AggregateMapper$1, reason: invalid class name */
    /* loaded from: input_file:org/elasticsearch/xpack/esql/planner/AggregateMapper$1.class */
    /**
     * Holder for the lookup table that {@code toDataType} switches on.
     *
     * <p>The table is indexed by {@link ElementType#ordinal()} and holds the case label (1-5) for each
     * element type handled by the switch. Entries that are never assigned keep the array default of 0
     * and therefore reach the switch's {@code default} branch.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$elasticsearch$compute$data$ElementType = new int[ElementType.values().length];

        static {
            // Each assignment is guarded on its own so that a NoSuchFieldError for one constant
            // leaves only that entry unset and does not prevent the remaining entries from being filled.
            // NOTE(review): presumably a guard against an ElementType version lacking a constant - confirm.
            try {
                $SwitchMap$org$elasticsearch$compute$data$ElementType[ElementType.BOOLEAN.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$elasticsearch$compute$data$ElementType[ElementType.BYTES_REF.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$elasticsearch$compute$data$ElementType[ElementType.INT.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$elasticsearch$compute$data$ElementType[ElementType.LONG.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$elasticsearch$compute$data$ElementType[ElementType.DOUBLE.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef.class */
    public static final class AggDef extends Record {
        private final Class<?> aggClazz;
        private final String type;
        private final String extra;
        private final boolean grouping;

        AggDef(Class<?> cls, String str, String str2, boolean z) {
            this.aggClazz = cls;
            this.type = str;
            this.extra = str2;
            this.grouping = z;
        }

        @Override // java.lang.Record
        public final String toString() {
            return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, AggDef.class), AggDef.class, "aggClazz;type;extra;grouping", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->aggClazz:Ljava/lang/Class;", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->type:Ljava/lang/String;", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->extra:Ljava/lang/String;", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->grouping:Z").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final int hashCode() {
            return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, AggDef.class), AggDef.class, "aggClazz;type;extra;grouping", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->aggClazz:Ljava/lang/Class;", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->type:Ljava/lang/String;", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->extra:Ljava/lang/String;", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->grouping:Z").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final boolean equals(Object obj) {
            return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, AggDef.class, Object.class), AggDef.class, "aggClazz;type;extra;grouping", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->aggClazz:Ljava/lang/Class;", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->type:Ljava/lang/String;", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->extra:Ljava/lang/String;", "FIELD:Lorg/elasticsearch/xpack/esql/planner/AggregateMapper$AggDef;->grouping:Z").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
        }

        public Class<?> aggClazz() {
            return this.aggClazz;
        }

        public String type() {
            return this.type;
        }

        public String extra() {
            return this.extra;
        }

        public boolean grouping() {
            return this.grouping;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /** Creates a mapper covering every aggregate function class listed in {@code AGG_FUNCTIONS}. */
    public AggregateMapper() {
        this(AGG_FUNCTIONS);
    }

    AggregateMapper(List<? extends Class<? extends Function>> list) {
        this.cache = new HashMap<>();
        this.mapper = (Map) list.stream().flatMap(AggregateMapper::typeAndNames).flatMap(AggregateMapper::groupingAndNonGrouping).collect(Collectors.toUnmodifiableMap(aggDef -> {
            return aggDef;
        }, AggregateMapper::lookupIntermediateState));
    }

    /**
     * Maps the given expressions to their intermediate attributes, using the non-grouping aggregator variants.
     *
     * @param list expressions to map
     * @return the intermediate expressions, one per distinct attribute
     */
    public List<? extends NamedExpression> mapNonGrouping(List<? extends Expression> list) {
        final boolean grouping = false;
        return doMapping(list, grouping);
    }

    /**
     * Maps a single expression to its intermediate attributes, using the non-grouping aggregator variant.
     *
     * @param expression expression to map; an alias is unwrapped first
     * @return the intermediate expressions for {@code expression}
     */
    public List<? extends NamedExpression> mapNonGrouping(Expression expression) {
        Stream<? extends NamedExpression> intermediate = map(expression, false);
        return intermediate.toList();
    }

    /**
     * Maps the given expressions to their intermediate attributes, using the grouping aggregator variants.
     *
     * @param list expressions to map
     * @return the intermediate expressions, one per distinct attribute
     */
    public List<? extends NamedExpression> mapGrouping(List<? extends Expression> list) {
        final boolean grouping = true;
        return doMapping(list, grouping);
    }

    /**
     * Maps every expression and collects the results keyed by attribute, so that a later expression
     * producing the same attribute replaces the earlier entry instead of adding a second one.
     *
     * @param aggregates expressions to map
     * @param grouping   whether the grouping aggregator variants are requested
     * @return the values of the attribute-keyed collection, as a list
     */
    private List<? extends NamedExpression> doMapping(List<? extends Expression> aggregates, boolean grouping) {
        AttributeMap<NamedExpression> byAttribute = new AttributeMap<>();
        for (Expression aggregate : aggregates) {
            map(aggregate, grouping).forEach(intermediate -> byAttribute.put(intermediate.toAttribute(), intermediate));
        }
        return byAttribute.values().stream().toList();
    }

    /**
     * Maps a single expression to its intermediate attributes, using the grouping aggregator variant.
     *
     * @param expression expression to map; an alias is unwrapped first
     * @return the intermediate expressions for {@code expression}
     */
    public List<? extends NamedExpression> mapGrouping(Expression expression) {
        Stream<? extends NamedExpression> intermediate = map(expression, true);
        return intermediate.toList();
    }

    /**
     * Returns the intermediate expressions for one expression, computing and caching them on first use.
     *
     * <p>NOTE(review): the cache key is the un-aliased expression only; {@code grouping} is not part of it.
     * An entry computed for one value of the flag is therefore returned for the other as well. That is
     * only harmless if both variants describe the same intermediate state - confirm.
     *
     * @param expression expression to map; an alias is unwrapped before the cache lookup
     * @param grouping   whether the grouping aggregator variant is requested (used only on a cache miss)
     * @return a stream over the cached intermediate expressions
     */
    private Stream<? extends NamedExpression> map(Expression expression, boolean grouping) {
        Expression key = Alias.unwrap(expression);
        List<? extends NamedExpression> entry = this.cache.computeIfAbsent(key, agg -> computeEntryForAgg(agg, grouping));
        return entry.stream();
    }

    /**
     * Computes the intermediate expressions for an un-aliased expression.
     *
     * @param aggregate expression to resolve
     * @param grouping  whether the grouping aggregator variant is requested
     * @return the intermediate attributes for an aggregate function, or an empty list for a
     *         field, metadata or reference attribute
     * @throws EsqlIllegalArgumentException if the expression is neither an aggregate function nor one of
     *         those attribute kinds, or if no intermediate state is registered for it
     */
    private List<? extends NamedExpression> computeEntryForAgg(Expression aggregate, boolean grouping) {
        AggDef definition = aggDefOrNull(aggregate, grouping);
        if (definition == null) {
            boolean plainAttribute = aggregate instanceof FieldAttribute
                || aggregate instanceof MetadataAttribute
                || aggregate instanceof ReferenceAttribute;
            if (plainAttribute == false) {
                throw new EsqlIllegalArgumentException("unknown agg: " + aggregate.getClass() + ": " + aggregate);
            }
            // Plain attributes carry no intermediate aggregation state.
            return List.of();
        }
        return isToNE(getNonNull(definition)).toList();
    }

    /**
     * Looks up the intermediate state registered for an aggregator definition.
     *
     * @param aggDef definition to look up
     * @return the registered intermediate state description
     * @throws EsqlIllegalArgumentException if nothing is registered for {@code aggDef}
     */
    private List<IntermediateStateDesc> getNonNull(AggDef aggDef) {
        List<IntermediateStateDesc> states = this.mapper.get(aggDef);
        if (states != null) {
            return states;
        }
        throw new EsqlIllegalArgumentException("Cannot find intermediate state for: " + aggDef);
    }

    /**
     * Expands an aggregate function class into every (type, extra) name combination it supports.
     *
     * <ul>
     *   <li>{@code NumericAggregate} subclasses: the {@code NUMERIC} types.</li>
     *   <li>{@code Count}: a single empty type.</li>
     *   <li>{@code SpatialAggregateFunction} subclasses: the {@code SPATIAL} types, each with the
     *       extras {@code "SourceValues"} and {@code "DocValues"}.</li>
     *   <li>{@code Values} subclasses: Int, Long, Double, Boolean and BytesRef.</li>
     *   <li>anything else is treated as {@code CountDistinct}: the {@code NUMERIC} types plus Boolean and BytesRef.</li>
     * </ul>
     *
     * @param cls aggregate function class
     * @return one tuple of (class, (type, extra)) per combination
     */
    private static Stream<Tuple<Class<?>, Tuple<String, String>>> typeAndNames(Class<?> cls) {
        List<String> types;
        // Only spatial aggregations have more than the single empty extra.
        List<String> extras = List.of("");
        if (NumericAggregate.class.isAssignableFrom(cls)) {
            types = NUMERIC;
        } else if (cls == Count.class) {
            types = List.of("");
        } else if (SpatialAggregateFunction.class.isAssignableFrom(cls)) {
            types = SPATIAL;
            extras = List.of("SourceValues", "DocValues");
        } else if (Values.class.isAssignableFrom(cls)) {
            types = List.of("Int", "Long", "Double", "Boolean", "BytesRef");
        } else {
            // Enforced only when assertions are enabled for this class.
            if (!$assertionsDisabled && cls != CountDistinct.class) {
                throw new AssertionError("Expected CountDistinct, got: " + cls);
            }
            // Pass the strings as varargs so the stream stays a Stream<String>.
            types = Stream.concat(NUMERIC.stream(), Stream.of("Boolean", "BytesRef")).toList();
        }
        return combinations(types, extras).map(combo -> new Tuple<>(cls, combo));
    }

    /**
     * Builds the cross product of two lists as (first, second) tuples.
     *
     * <p>The lambda parameters carry distinct names so that each tuple pairs an element of
     * {@code types} with an element of {@code extras}, rather than an element with itself.
     *
     * @param types  values for the first tuple component
     * @param extras values for the second tuple component
     * @return every (type, extra) pair, iterating {@code extras} within each element of {@code types}
     */
    private static Stream<Tuple<String, String>> combinations(List<String> types, List<String> extras) {
        return types.stream().flatMap(type -> extras.stream().map(extra -> new Tuple<>(type, extra)));
    }

    /**
     * Turns one (class, (type, extra)) combination into its two aggregator definitions.
     *
     * @param tuple aggregate function class paired with its (type, extra) name fragments
     * @return the grouping definition followed by the non-grouping definition
     */
    private static Stream<AggDef> groupingAndNonGrouping(Tuple<Class<?>, Tuple<String, String>> tuple) {
        Class<?> aggClass = tuple.v1();
        String type = tuple.v2().v1();
        String extra = tuple.v2().v2();
        // Pass the definitions as varargs so the stream stays a Stream<AggDef>.
        return Stream.of(new AggDef(aggClass, type, extra, true), new AggDef(aggClass, type, extra, false));
    }

    /**
     * Derives the aggregator definition for an expression.
     *
     * @param expression expression to inspect
     * @param grouping   whether the grouping aggregator variant is requested
     * @return the definition when {@code expression} is an {@code AggregateFunction}, otherwise {@code null};
     *         the extra fragment is {@code "SourceValues"} for {@code SpatialCentroid} and empty for everything else
     */
    private static AggDef aggDefOrNull(Expression expression, boolean grouping) {
        if (expression instanceof AggregateFunction aggregate) {
            Class<?> aggClass = aggregate.getClass();
            String type = dataTypeToString(aggregate.field().dataType(), aggClass);
            String extra = expression instanceof SpatialCentroid ? "SourceValues" : "";
            return new AggDef(aggClass, type, extra, grouping);
        }
        return null;
    }

    /**
     * Resolves the intermediate state description for a definition by invoking the static
     * {@code intermediateStateDesc} method of the matching aggregator class.
     *
     * @param aggDef definition whose aggregator class is looked up
     * @return the list returned by the aggregator's {@code intermediateStateDesc} method
     * @throws EsqlIllegalArgumentException wrapping anything thrown while resolving or invoking the method
     */
    @SuppressWarnings("unchecked")
    private static List<IntermediateStateDesc> lookupIntermediateState(AggDef aggDef) {
        try {
            // The handle is resolved inside the try so that resolution failures are wrapped here as well.
            MethodHandle stateAccessor = lookup(aggDef.aggClazz(), aggDef.type(), aggDef.extra(), aggDef.grouping());
            // invokeExact derives its call signature from this cast; it must erase to List.
            return (List<IntermediateStateDesc>) stateAccessor.invokeExact();
        } catch (Throwable failure) {
            throw new EsqlIllegalArgumentException(failure);
        }
    }

    /**
     * Finds the static, no-argument {@code intermediateStateDesc} method (returning {@code List}) on the
     * aggregator class named by the given fragments.
     *
     * @param clazz    aggregate function class
     * @param type     type fragment of the aggregator class name
     * @param extra    extra fragment of the aggregator class name
     * @param grouping whether the grouping aggregator variant is requested
     * @return a handle to the aggregator's {@code intermediateStateDesc} method
     * @throws EsqlIllegalArgumentException if the class or method is missing or inaccessible
     */
    private static MethodHandle lookup(Class<?> clazz, String type, String extra, boolean grouping) {
        String aggregatorName = determineAggName(clazz, type, extra, grouping);
        try {
            Class<?> aggregatorClass = Class.forName(aggregatorName);
            MethodType signature = MethodType.methodType(List.class);
            return MethodHandles.lookup().findStatic(aggregatorClass, "intermediateStateDesc", signature);
        } catch (ClassNotFoundException | IllegalAccessException | NoSuchMethodException e) {
            throw new EsqlIllegalArgumentException(e);
        }
    }

    /**
     * Assembles the fully qualified aggregator class name:
     * {@code <package>.<SimpleName><type><extra>[Grouping]AggregatorFunction}.
     *
     * @param clazz    aggregate function class; supplies the package choice and the simple name
     * @param type     type fragment
     * @param extra    extra fragment
     * @param grouping whether to insert {@code "Grouping"} before the suffix
     * @return the aggregator class name
     */
    private static String determineAggName(Class<?> clazz, String type, String extra, boolean grouping) {
        String groupingPart = grouping ? "Grouping" : "";
        return determinePackageName(clazz) + "." + clazz.getSimpleName() + type + extra + groupingPart + "AggregatorFunction";
    }

    /**
     * Chooses the package that holds the aggregator implementations for a function class.
     *
     * @param clazz aggregate function class
     * @return the spatial aggregation package when the simple name starts with {@code "Spatial"},
     *         otherwise the general aggregation package
     */
    private static String determinePackageName(Class<?> clazz) {
        if (clazz.getSimpleName().startsWith("Spatial")) {
            return "org.elasticsearch.compute.aggregation.spatial";
        }
        return "org.elasticsearch.compute.aggregation";
    }

    /**
     * Converts intermediate state descriptions into reference attributes with an empty source,
     * carrying over each description's name and mapping its element type to a data type.
     *
     * @param states intermediate state descriptions
     * @return one reference attribute per description, in the same order
     */
    private static Stream<NamedExpression> isToNE(List<IntermediateStateDesc> states) {
        return states.stream().map(state -> new ReferenceAttribute(Source.EMPTY, state.name(), toDataType(state.type())));
    }

    /**
     * Maps a compute-engine element type to the corresponding data type.
     *
     * @param elementType element type of an intermediate state column
     * @return BOOLEAN, KEYWORD (for BYTES_REF), INTEGER, LONG or DOUBLE
     * @throws EsqlIllegalArgumentException for any other element type
     */
    private static DataType toDataType(ElementType elementType) {
        return switch (elementType) {
            case BOOLEAN -> DataTypes.BOOLEAN;
            case BYTES_REF -> DataTypes.KEYWORD;
            case INT -> DataTypes.INTEGER;
            case LONG -> DataTypes.LONG;
            case DOUBLE -> DataTypes.DOUBLE;
            default -> throw new EsqlIllegalArgumentException("unsupported agg type: " + elementType);
        };
    }

    /**
     * Returns the type fragment used in aggregator class names for a field data type.
     *
     * @param type     data type of the aggregated field
     * @param aggClass aggregate function class; {@code Count} always yields the empty fragment
     * @return {@code ""}, {@code "Boolean"}, {@code "Int"}, {@code "Long"}, {@code "Double"},
     *         {@code "BytesRef"}, {@code "GeoPoint"} or {@code "CartesianPoint"}
     * @throws EsqlIllegalArgumentException if the data type has no fragment
     */
    private static String dataTypeToString(DataType type, Class<?> aggClass) {
        final String fragment;
        if (aggClass == Count.class) {
            // Count is not specialised by field type.
            fragment = "";
        } else if (type.equals(DataTypes.BOOLEAN)) {
            fragment = "Boolean";
        } else if (type.equals(DataTypes.INTEGER)) {
            fragment = "Int";
        } else if (type.equals(DataTypes.LONG) || type.equals(DataTypes.DATETIME)) {
            fragment = "Long";
        } else if (type.equals(DataTypes.DOUBLE)) {
            fragment = "Double";
        } else if (type.equals(DataTypes.KEYWORD)
            || type.equals(DataTypes.IP)
            || type.equals(DataTypes.VERSION)
            || type.equals(DataTypes.TEXT)) {
            fragment = "BytesRef";
        } else if (type.equals(EsqlDataTypes.GEO_POINT)) {
            fragment = "GeoPoint";
        } else if (type.equals(EsqlDataTypes.CARTESIAN_POINT)) {
            fragment = "CartesianPoint";
        } else {
            throw new EsqlIllegalArgumentException("illegal agg type: " + type.typeName());
        }
        return fragment;
    }

    /**
     * Returns the child of an alias, or the expression itself when it is not an alias.
     *
     * @param expression expression to unwrap
     * @return the aliased child, or {@code expression} unchanged
     */
    private static Expression unwrapAlias(Expression expression) {
        if (expression instanceof Alias alias) {
            return alias.child();
        }
        return expression;
    }

    // Initializes the blank-final static fields declared at the top of the class.
    static {
        // Captures whether assertions are disabled for this class; read by the guarded check in typeAndNames.
        $assertionsDisabled = !AggregateMapper.class.desiredAssertionStatus();
        // Type fragments appended to aggregator class names.
        NUMERIC = List.of("Int", "Long", "Double");
        SPATIAL = List.of("GeoPoint", "CartesianPoint");
        // Aggregate function classes the default constructor registers.
        AGG_FUNCTIONS = List.of(Count.class, CountDistinct.class, Max.class, MedianAbsoluteDeviation.class, Min.class, Percentile.class, SpatialCentroid.class, Sum.class, Values.class);
    }
}
