package org.apache.ignite.internal.sql.engine.rel.agg;

import com.google.common.collect.ImmutableList;
import java.math.BigDecimal;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.ignite.internal.lang.IgniteStringFormatter;
import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
import org.apache.ignite.internal.sql.engine.rel.IgniteRel;
import org.apache.ignite.internal.sql.engine.sql.fun.IgniteSqlOperatorTable;
import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory;
import org.apache.ignite.internal.sql.engine.util.Commons;
import org.jetbrains.annotations.TestOnly;

/**
 * Splits a {@link LogicalAggregate} into a MAP phase aggregate and a REDUCE phase aggregate.
 *
 * <p>For every aggregate call a {@link MapReduceAgg} descriptor is built. Depending on the descriptors, a projection may be
 * inserted between the MAP and REDUCE nodes (to adapt MAP results to the inputs of the REDUCE aggregates), and a projection
 * may be added on top of the REDUCE node (to combine REDUCE results into the value of the original aggregate).
 *
 * <pre>
 * SELECT c1, MIN(c2), COUNT(c3) FROM test GROUP BY c1
 *
 * MAP      [c1, map_agg1, map_agg2]
 * REDUCE   [c1, reduce_agg1, reduce_agg2]
 * PROJECT  [c1, expr_agg1, expr_agg2]   (omitted when every aggregate uses its REDUCE result as is)
 * </pre>
 */
public final class MapReduceAggregates {
    /** Names of aggregate functions that can be split into MAP and REDUCE phases. */
    private static final Set<String> AGG_SUPPORTING_MAP_REDUCE = Set.of(
            "COUNT",
            "MIN",
            "MAX",
            "SUM",
            "$SUM0",
            "EVERY",
            "SOME",
            "ANY",
            "AVG",
            "SINGLE_VALUE",
            "ANY_VALUE"
    );

    /** Expression factory that returns a reference to the first argument position of the given input. */
    private static final MakeReduceExpr USE_INPUT_FIELD =
            (rexBuilder, input, args, typeFactory) -> rexBuilder.makeInputRef(input, args.get(0));

    /** Factory for the relational nodes produced by {@link #buildAggregates}. */
    public interface AggregateRelBuilder {
        /** Creates the MAP phase aggregate over {@code input}. */
        IgniteRel makeMapAgg(RelOptCluster cluster, RelNode input, ImmutableBitSet groupSet,
                List<ImmutableBitSet> groupSets, List<AggregateCall> aggregateCalls);

        /** Creates a projection of {@code exprs} over {@code input} with the given row type. */
        IgniteRel makeProject(RelOptCluster cluster, RelNode input, List<RexNode> exprs, RelDataType rowType);

        /** Creates the REDUCE phase aggregate over {@code input} with the given output row type. */
        IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode input, ImmutableBitSet groupSet,
                List<ImmutableBitSet> groupSets, List<AggregateCall> aggregateCalls, RelDataType outputType);
    }

    /** Builds an expression over {@code input} from the fields at positions {@code args}. */
    @FunctionalInterface
    public interface MakeReduceExpr {
        RexNode makeExpr(RexBuilder rexBuilder, RelNode input, List<Integer> args, IgniteTypeFactory typeFactory);
    }

    /** Describes how a single aggregate call is split into MAP and REDUCE phases. */
    public static class MapReduceAgg {
        /** Positions of the REDUCE phase outputs passed to {@link #makeReduceOutputExpr}. */
        final List<Integer> argList;

        /** Aggregate calls of the MAP phase. */
        final List<AggregateCall> mapCalls;

        /** Aggregate calls of the REDUCE phase. */
        final List<AggregateCall> reduceCalls;

        /** Builds the expression that feeds a MAP phase output into a REDUCE phase aggregate. */
        final MakeReduceExpr makeReduceInputExpr;

        /** Builds the final value of the aggregate from the REDUCE phase outputs. */
        final MakeReduceExpr makeReduceOutputExpr;

        /** Creates a descriptor with one MAP call and one REDUCE call whose MAP output is consumed unchanged. */
        MapReduceAgg(List<Integer> argList, AggregateCall mapCall, AggregateCall reduceCall, MakeReduceExpr makeReduceOutputExpr) {
            this(argList, List.of(mapCall), USE_INPUT_FIELD, List.of(reduceCall), makeReduceOutputExpr);
        }

        MapReduceAgg(
                List<Integer> argList,
                List<AggregateCall> mapCalls,
                MakeReduceExpr makeReduceInputExpr,
                List<AggregateCall> reduceCalls,
                MakeReduceExpr makeReduceOutputExpr
        ) {
            this.argList = argList;
            this.mapCalls = mapCalls;
            this.reduceCalls = reduceCalls;
            this.makeReduceInputExpr = makeReduceInputExpr;
            this.makeReduceOutputExpr = makeReduceOutputExpr;
        }

        /** Returns the first REDUCE phase call. */
        @TestOnly
        public AggregateCall getReduceCall() {
            return reduceCalls.get(0);
        }
    }

    private MapReduceAggregates() {
    }

    /**
     * Checks whether every given aggregate call can be split into MAP and REDUCE phases.
     *
     * @param aggCalls Aggregate calls to check.
     * @return {@code true} if every call's aggregate function name is in the supported set.
     */
    public static boolean canBeImplementedAsMapReduce(List<AggregateCall> aggCalls) {
        for (AggregateCall call : aggCalls) {
            if (!AGG_SUPPORTING_MAP_REDUCE.contains(call.getAggregation().getName())) {
                return false;
            }
        }

        return true;
    }

    /**
     * Builds the MAP/REDUCE version of the given aggregate.
     *
     * @param agg Aggregate to split.
     * @param builder Factory for the created relational nodes.
     * @param fieldMappingOnReduce Mapping applied to the group set and grouping sets of the REDUCE aggregate.
     * @return The REDUCE aggregate, or a projection over it when some aggregate needs a final expression.
     */
    public static IgniteRel buildAggregates(LogicalAggregate agg, AggregateRelBuilder builder, Mapping fieldMappingOnReduce) {
        RelOptCluster cluster = agg.getCluster();
        List<AggregateCall> aggCalls = agg.getAggCallList();
        List<RelDataTypeField> outputFields = agg.getRowType().getFieldList();
        // Group columns come first in the output of every phase; aggregate results follow them.
        int groupKeyCount = agg.getGroupSet().cardinality();

        // Build a MAP/REDUCE descriptor for every aggregate call.
        List<MapReduceAgg> mapReduceAggs = new ArrayList<>(aggCalls.size());
        List<AggregateCall> mapAggCalls = new ArrayList<>(aggCalls.size());
        int argumentOffset = groupKeyCount;

        for (AggregateCall call : aggCalls) {
            // Partial results are treated as nullable without GROUP BY or with a FILTER (see createAvgAgg).
            boolean canBeNull = agg.getGroupCount() == 0 || call.hasFilter();
            MapReduceAgg mapReduceAgg = createMapReduceAggCall(
                    Commons.cluster(), call, argumentOffset, agg.getInput().getRowType(), canBeNull);

            argumentOffset += mapReduceAgg.reduceCalls.size();
            mapReduceAggs.add(mapReduceAgg);
            mapAggCalls.addAll(mapReduceAgg.mapCalls);
        }

        assert mapAggCalls.size() >= aggCalls.size() : IgniteStringFormatter.format(
                "The number of MAP aggregates is not correct. Original: {}\nMAP: {}", aggCalls, mapAggCalls);

        // MAP phase.
        RelNode map = builder.makeMapAgg(cluster, agg.getInput(), agg.getGroupSet(), agg.getGroupSets(), mapAggCalls);

        RexBuilder rexBuilder = cluster.getRexBuilder();
        IgniteTypeFactory typeFactory = (IgniteTypeFactory) cluster.getTypeFactory();

        // Optional projection that adapts MAP outputs to the inputs of the REDUCE aggregates.
        List<RelDataTypeField> mapFields = map.getRowType().getFieldList();
        List<RexNode> reduceInputExprs = new ArrayList<>(mapFields.size());
        for (int i = 0; i < mapFields.size(); i++) {
            reduceInputExprs.add(new RexInputRef(i, mapFields.get(i).getType()));
        }

        boolean projectionRequired = false;
        int position = groupKeyCount;
        for (MapReduceAgg mapReduceAgg : mapReduceAggs) {
            for (int j = 0; j < mapReduceAgg.reduceCalls.size(); j++, position++) {
                RexNode expr = mapReduceAgg.makeReduceInputExpr.makeExpr(rexBuilder, map, List.of(position), typeFactory);
                reduceInputExprs.set(position, expr);

                if (mapReduceAgg.makeReduceInputExpr != USE_INPUT_FIELD) {
                    projectionRequired = true;
                }
            }
        }

        RelNode reduceInput;
        if (projectionRequired) {
            RelDataTypeFactory.Builder projectRowType = new RelDataTypeFactory.Builder(typeFactory);
            for (int i = 0; i < reduceInputExprs.size(); i++) {
                projectRowType.add(String.valueOf(i), reduceInputExprs.get(i).getType());
            }
            reduceInput = builder.makeProject(cluster, map, reduceInputExprs, projectRowType.build());
        } else {
            reduceInput = map;
        }

        // REDUCE phase.
        RelDataTypeFactory.Builder reduceRowTypeBuilder = new RelDataTypeFactory.Builder(Commons.typeFactory());
        for (int i = 0; i < groupKeyCount; i++) {
            reduceRowTypeBuilder.add("f" + reduceRowTypeBuilder.getFieldCount(), outputFields.get(i).getType());
        }

        List<AggregateCall> reduceAggCalls = new ArrayList<>();
        // True when no aggregate needs a final expression, i.e. REDUCE already produces the original row.
        boolean sameAggsForBothPhases = true;

        for (MapReduceAgg mapReduceAgg : mapReduceAggs) {
            int i = 0;
            for (AggregateCall reduceCall : mapReduceAgg.reduceCalls) {
                reduceRowTypeBuilder.add("f" + i + "_" + reduceRowTypeBuilder.getFieldCount(), reduceCall.getType());
                reduceAggCalls.add(reduceCall);
                i++;
            }

            if (mapReduceAgg.makeReduceOutputExpr != USE_INPUT_FIELD) {
                sameAggsForBothPhases = false;
            }
        }

        RelDataType reduceRowType = sameAggsForBothPhases ? agg.getRowType() : reduceRowTypeBuilder.build();

        assert mapAggCalls.size() <= reduceAggCalls.size() : IgniteStringFormatter.format(
                "The number of MAP/REDUCE aggregates is not correct. MAP: {}\nREDUCE: {}", mapAggCalls, reduceAggCalls);

        List<ImmutableBitSet> reduceGroupSets = agg.getGroupSets().stream()
                .map(set -> Mappings.apply(fieldMappingOnReduce, set))
                .collect(Collectors.toList());

        IgniteRel reduce = builder.makeReduceAgg(
                cluster,
                reduceInput,
                Mappings.apply(fieldMappingOnReduce, agg.getGroupSet()),
                reduceGroupSets,
                reduceAggCalls,
                reduceRowType
        );

        if (sameAggsForBothPhases) {
            return reduce;
        }

        // Final projection that combines REDUCE outputs into the values of the original aggregates.
        List<RexNode> projection = new ArrayList<>(mapReduceAggs.size() + groupKeyCount);
        for (int i = 0; i < groupKeyCount; i++) {
            projection.add(new RexInputRef(i, outputFields.get(i).getType()));
        }
        for (MapReduceAgg mapReduceAgg : mapReduceAggs) {
            projection.add(mapReduceAgg.makeReduceOutputExpr.makeExpr(rexBuilder, reduce, mapReduceAgg.argList, typeFactory));
        }

        assert projection.size() == outputFields.size() : IgniteStringFormatter.format(
                "Projection size does not match. Expected: {} but got {}", outputFields.size(), projection.size());

        for (int i = 0; i < projection.size(); i++) {
            RexNode expr = projection.get(i);
            RelDataType expectedType = outputFields.get(i).getType();

            assert expr.getType().equals(expectedType) : IgniteStringFormatter.format(
                    "Type at position#{} does not match. Expected: {} but got {}.\nREDUCE aggregates: {}\nRow: {}.\nExpr: {}",
                    i, expectedType, expr.getType(), reduceAggCalls, outputFields, expr);
        }

        return new IgniteProject(cluster, reduce.getTraitSet(), reduce, projection, agg.getRowType());
    }

    /**
     * Creates the MAP/REDUCE descriptor for a single aggregate call.
     *
     * @param cluster Cluster whose type factory is used to derive intermediate types.
     * @param call Aggregate call to split; its function must be supported (see {@link #canBeImplementedAsMapReduce}).
     * @param reduceArgumentOffset Position of this call's first MAP output in the input of the REDUCE phase.
     * @param input Row type of the aggregate's input.
     * @param canBeNull Whether intermediate SUM types must be nullable.
     * @return The descriptor.
     */
    public static MapReduceAgg createMapReduceAggCall(
            RelOptCluster cluster,
            AggregateCall call,
            int reduceArgumentOffset,
            RelDataType input,
            boolean canBeNull
    ) {
        String aggName = call.getAggregation().getName();

        assert AGG_SUPPORTING_MAP_REDUCE.contains(aggName) : "Aggregate does not support MAP/REDUCE " + call;

        switch (aggName) {
            case "COUNT":
                return createCountAgg(call, reduceArgumentOffset);
            case "AVG":
                return createAvgAgg(cluster, call, reduceArgumentOffset, input, canBeNull);
            default:
                return createSimpleAgg(call, reduceArgumentOffset);
        }
    }

    /** COUNT on MAP, $SUM0 of the partial counts on REDUCE, then a cast of the result to BIGINT. */
    private static MapReduceAgg createCountAgg(AggregateCall call, int reduceArgumentOffset) {
        List<Integer> argList = List.of(reduceArgumentOffset);

        AggregateCall sum0 = AggregateCall.create(
                SqlStdOperatorTable.SUM0,
                call.isDistinct(),
                call.isApproximate(),
                call.ignoreNulls(),
                ImmutableList.of(),
                argList,
                -1,
                (ImmutableBitSet) null,
                call.collation,
                call.type,
                "COUNT_" + reduceArgumentOffset + "_MAP_SUM"
        );

        MakeReduceExpr castToBigint = (rexBuilder, input, args, typeFactory) -> rexBuilder.makeCast(
                typeFactory.createSqlType(SqlTypeName.BIGINT),
                rexBuilder.makeInputRef(input, args.get(0)),
                true,
                false
        );

        return new MapReduceAgg(argList, call, sum0, castToBigint);
    }

    /** The same aggregate function on both phases; the REDUCE result is used as is. */
    private static MapReduceAgg createSimpleAgg(AggregateCall call, int reduceArgumentOffset) {
        List<Integer> argList = List.of(reduceArgumentOffset);

        AggregateCall reduceCall = AggregateCall.create(
                call.getAggregation(),
                call.isDistinct(),
                call.isApproximate(),
                call.ignoreNulls(),
                ImmutableList.of(),
                argList,
                -1,
                call.distinctKeys,
                call.collation,
                call.type,
                call.name
        );

        return new MapReduceAgg(argList, call, reduceCall, USE_INPUT_FIELD);
    }

    /**
     * AVG is split into SUM and COUNT on MAP, SUM of the partial sums and $SUM0 of the partial counts on REDUCE,
     * and {@code DECIMAL_DIVIDE(sum, count)} on top, cast to the AVG type when that type is not DECIMAL.
     *
     * <p>When {@code canBeNull} is set, the result is {@code CASE WHEN count = 0 THEN NULL ELSE sum / count END},
     * i.e. NULL only when no row contributed to the average (a zero sum still yields a zero average).
     */
    private static MapReduceAgg createAvgAgg(
            RelOptCluster cluster,
            AggregateCall call,
            int reduceArgumentOffset,
            RelDataType inputType,
            boolean canBeNull
    ) {
        RelDataTypeFactory tf = cluster.getTypeFactory();
        RelDataTypeSystem typeSystem = tf.getTypeSystem();

        RelDataType fieldType = inputType.getFieldList().get(call.getArgList().get(0)).getType();

        // An argument of NULL type is not split into SUM/COUNT.
        if (fieldType.getSqlTypeName() == SqlTypeName.NULL) {
            return createSimpleAgg(call, reduceArgumentOffset);
        }

        // MAP phase: SUM(x), COUNT(x).
        RelDataType derivedMapSumType = typeSystem.deriveSumType(tf, fieldType);
        RelDataType mapSumType = canBeNull ? tf.createTypeWithNullability(derivedMapSumType, true) : derivedMapSumType;

        AggregateCall mapSum = AggregateCall.create(
                SqlStdOperatorTable.SUM,
                call.isDistinct(),
                call.isApproximate(),
                call.ignoreNulls(),
                ImmutableList.of(),
                call.getArgList(),
                call.filterArg,
                (ImmutableBitSet) null,
                call.collation,
                mapSumType,
                "AVG_SUM" + reduceArgumentOffset
        );

        AggregateCall mapCount = AggregateCall.create(
                SqlStdOperatorTable.COUNT,
                call.isDistinct(),
                call.isApproximate(),
                call.ignoreNulls(),
                ImmutableList.of(),
                call.getArgList(),
                call.filterArg,
                (ImmutableBitSet) null,
                call.collation,
                tf.createSqlType(SqlTypeName.BIGINT),
                "AVG_COUNT" + reduceArgumentOffset
        );

        // REDUCE phase: SUM(partial sums), $SUM0(partial counts).
        RelDataType derivedReduceSumType = typeSystem.deriveSumType(tf, mapSumType);
        RelDataType reduceSumType = canBeNull ? tf.createTypeWithNullability(derivedReduceSumType, true) : derivedReduceSumType;

        AggregateCall reduceSum = AggregateCall.create(
                SqlStdOperatorTable.SUM,
                call.isDistinct(),
                call.isApproximate(),
                call.ignoreNulls(),
                ImmutableList.of(),
                List.of(reduceArgumentOffset),
                -1,
                (ImmutableBitSet) null,
                call.collation,
                reduceSumType,
                "AVG_SUM" + reduceArgumentOffset
        );

        AggregateCall reduceSum0 = AggregateCall.create(
                SqlStdOperatorTable.SUM0,
                call.isDistinct(),
                call.isApproximate(),
                call.ignoreNulls(),
                ImmutableList.of(),
                List.of(reduceArgumentOffset + 1),
                -1,
                (ImmutableBitSet) null,
                call.collation,
                typeSystem.deriveSumType(tf, mapCount.type),
                "AVG_SUM0" + reduceArgumentOffset
        );

        // Casts MAP outputs to the types expected by the REDUCE aggregates.
        MakeReduceExpr reduceInputExpr = (rexBuilder, input, args, typeFactory) -> {
            int argIdx = args.get(0);
            RexInputRef ref = rexBuilder.makeInputRef(input, argIdx);

            if (argIdx == reduceArgumentOffset) {
                // Partial sum: cast only when the types differ beyond nullability.
                return SqlTypeUtil.equalSansNullability(reduceSumType, ref.getType())
                        ? ref
                        : rexBuilder.makeCast(reduceSumType, ref, true, false);
            }

            // Partial count.
            return rexBuilder.makeCast(reduceSum0.type, ref, true, false);
        };

        // Combines REDUCE outputs: sum / count.
        MakeReduceExpr reduceOutputExpr = (rexBuilder, input, args, typeFactory) -> {
            RexInputRef sumRef = rexBuilder.makeInputRef(input, args.get(0));
            RexNode countRef = rexBuilder.makeInputRef(input, args.get(1));
            RexNode numerator = rexBuilder.ensureType(mapSum.type, sumRef, true);

            RelDataType decimalType = typeFactory.decimalOf(call.type);
            RexNode precision = rexBuilder.makeExactLiteral(
                    BigDecimal.valueOf(decimalType.getPrecision()), tf.createSqlType(SqlTypeName.INTEGER));
            RexNode scale = rexBuilder.makeExactLiteral(
                    BigDecimal.valueOf(decimalType.getScale()), tf.createSqlType(SqlTypeName.INTEGER));

            RexNode avg = rexBuilder.makeCall(IgniteSqlOperatorTable.DECIMAL_DIVIDE, numerator, countRef, precision, scale);

            if (call.getType().getSqlTypeName() != SqlTypeName.DECIMAL) {
                avg = rexBuilder.makeCast(call.getType(), avg, false, false);
            }

            if (!canBeNull) {
                return avg;
            }

            // No contributing rows: return NULL instead of dividing by a zero count.
            RexNode countIsZero = rexBuilder.makeCall(
                    SqlStdOperatorTable.EQUALS,
                    countRef,
                    rexBuilder.makeExactLiteral(BigDecimal.ZERO, countRef.getType())
            );

            return rexBuilder.makeCall(SqlStdOperatorTable.CASE, countIsZero, rexBuilder.makeNullLiteral(call.getType()), avg);
        };

        return new MapReduceAgg(
                List.of(reduceArgumentOffset, reduceArgumentOffset + 1),
                List.of(mapSum, mapCount),
                reduceInputExpr,
                List.of(reduceSum, reduceSum0),
                reduceOutputExpr
        );
    }
}
