/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite3.internal.sql.engine.metadata;

import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import java.lang.reflect.Method;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Intersect;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.metadata.MetadataHandler;
import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMdRowCount;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.mapping.IntPair;
import org.apache.ignite3.internal.sql.engine.rel.IgniteAggregate;
import org.apache.ignite3.internal.sql.engine.rel.IgniteLimit;
import org.apache.ignite3.internal.sql.engine.rel.IgniteSortedIndexSpool;
import org.apache.ignite3.internal.sql.engine.schema.IgniteTable;
import org.jetbrains.annotations.Nullable;

/**
 * Row-count metadata handler for Ignite relational operators.
 *
 * <p>Most operators delegate to their own {@code estimateRowCount}. Joins get a dedicated estimation
 * ({@link #joinRowCount(RelMetadataQuery, Join)}) that looks at the tables the equi-join keys originate from and
 * tries to detect primary-key / foreign-key style relationships; when that is not possible it falls back to
 * {@link RelMdUtil#getJoinRowCount}.
 */
public class IgniteMdRowCount extends RelMdRowCount {
    /**
     * Adjustment applied to INNER and SEMI join estimates when the condition is not a pure equi-join, or when the
     * key pairs span more than one pair of origin tables.
     */
    public static final double NON_EQUI_COEFF = 0.7;

    /** Metadata provider that routes {@code ROW_COUNT} requests to this handler. */
    public static final RelMetadataProvider SOURCE = ReflectiveRelMetadataProvider.reflectiveSource(
            BuiltInMethod.ROW_COUNT.method, new IgniteMdRowCount());

    /** Estimates the row count of a join; see {@link #joinRowCount(RelMetadataQuery, Join)}. */
    @Nullable
    public Double getRowCount(Join rel, RelMetadataQuery mq) {
        return joinRowCount(mq, rel);
    }

    /** Delegates to the node's own estimate. */
    public Double getRowCount(Sort rel, RelMetadataQuery mq) {
        return rel.estimateRowCount(mq);
    }

    /** Delegates to the node's own estimate. */
    public double getRowCount(IgniteSortedIndexSpool rel, RelMetadataQuery mq) {
        return rel.estimateRowCount(mq);
    }

    /** Delegates to the node's own estimate. */
    public Double getRowCount(Intersect rel, RelMetadataQuery mq) {
        return rel.estimateRowCount(mq);
    }

    /** Delegates to the node's own estimate. */
    public Double getRowCount(Minus rel, RelMetadataQuery mq) {
        return rel.estimateRowCount(mq);
    }

    /** Delegates to the node's own estimate. */
    public double getRowCount(IgniteAggregate rel, RelMetadataQuery mq) {
        return rel.estimateRowCount(mq);
    }

    /** Delegates to the node's own estimate. */
    public double getRowCount(IgniteLimit rel, RelMetadataQuery mq) {
        return rel.estimateRowCount(mq);
    }

    /**
     * Estimates the number of rows produced by a join.
     *
     * <p>For INNER, LEFT, RIGHT, FULL and SEMI joins with at least one equi-join key pair, the key columns are traced
     * back to their origin {@link IgniteTable}s and compared against the tables' primary keys. Depending on which
     * primary keys are fully covered, the estimate is based on one input's row count (or on both for FULL joins).
     * The result is finally clamped to the cardinality bounds implied by the join type: a SEMI join never exceeds
     * its left input, LEFT/RIGHT joins never drop below their preserved input, and a FULL join never drops below
     * the larger input. Every other case falls back to {@link RelMdUtil#getJoinRowCount}.
     *
     * @param mq Metadata query.
     * @param rel Join to estimate.
     * @return Estimated row count, or {@code null} if the row count of either input is unknown.
     */
    @Nullable
    public static Double joinRowCount(RelMetadataQuery mq, Join rel) {
        JoinRelType joinType = rel.getJoinType();
        if (!isSupportedJoinType(joinType)) {
            return fallbackRowCount(mq, rel);
        }

        JoinInfo joinInfo = rel.analyzeCondition();
        if (joinInfo.pairs().isEmpty()) {
            // No equi-join key pairs: nothing to match against table keys.
            return fallbackRowCount(mq, rel);
        }

        Double leftRowCount = mq.getRowCount(rel.getLeft());
        Double rightRowCount = mq.getRowCount(rel.getRight());
        if (leftRowCount == null || rightRowCount == null) {
            return null;
        }

        // A tiny input may let the max-row-count metadata pin the result to at most one row.
        if (leftRowCount <= 1.0 || rightRowCount <= 1.0) {
            Double max = mq.getMaxRowCount(rel);
            if (max != null && max <= 1.0) {
                return max;
            }
        }

        Map<TablesPair, JoinContext> joinContexts = collectJoinContexts(mq, rel, joinInfo);
        if (joinContexts.isEmpty()) {
            return fallbackRowCount(mq, rel);
        }

        JoiningRelationType relationType = strongestRelationType(joinContexts.values());
        if (relationType == JoiningRelationType.UNKNOWN) {
            return fallbackRowCount(mq, rel);
        }

        // Filtering joins are penalized when part of the condition is not explained by a single key relationship.
        double postFiltrationAdjustment = 1.0;
        if ((joinType == JoinRelType.INNER || joinType == JoinRelType.SEMI)
                && !(joinContexts.size() == 1 && joinInfo.isEqui())) {
            postFiltrationAdjustment = NON_EQUI_COEFF;
        }

        double baseRowCount;
        Double percentageAdjustment = null;
        switch (joinType) {
            case INNER:
            case SEMI:
                if (relationType == JoiningRelationType.PK_ON_PK) {
                    // Join keys cover the full primary key on both sides: base the estimate on the smaller input,
                    // scaled by the share of original rows kept on the larger input.
                    if (leftRowCount > rightRowCount) {
                        baseRowCount = rightRowCount;
                        percentageAdjustment = mq.getPercentageOriginalRows(rel.getLeft());
                    } else {
                        baseRowCount = leftRowCount;
                        percentageAdjustment = mq.getPercentageOriginalRows(rel.getRight());
                    }
                } else if (relationType == JoiningRelationType.FK_ON_PK) {
                    // Right side is joined on its full primary key; the left side drives the cardinality.
                    baseRowCount = leftRowCount;
                    percentageAdjustment = mq.getPercentageOriginalRows(rel.getRight());
                } else {
                    // PK_ON_FK: left side is joined on its full primary key; the right side drives the cardinality.
                    baseRowCount = rightRowCount;
                    percentageAdjustment = mq.getPercentageOriginalRows(rel.getLeft());
                }
                break;
            case LEFT:
                baseRowCount = relationType == JoiningRelationType.PK_ON_FK ? rightRowCount : leftRowCount;
                break;
            case RIGHT:
                baseRowCount = relationType == JoiningRelationType.FK_ON_PK ? leftRowCount : rightRowCount;
                break;
            case FULL: {
                Double selectivity = mq.getSelectivity(rel, rel.getCondition());
                if (selectivity == null) {
                    return fallbackRowCount(mq, rel);
                }
                baseRowCount = rightRowCount + leftRowCount;
                percentageAdjustment = 1.0 - selectivity;
                break;
            }
            default:
                // Unreachable: filtered out by isSupportedJoinType above.
                throw new AssertionError("Unexpected join type: " + joinType);
        }

        // No percentage info available: be conservative and do not scale.
        double estimate = baseRowCount
                * (percentageAdjustment == null ? 1.0 : percentageAdjustment)
                * postFiltrationAdjustment;

        return clampToJoinBounds(joinType, estimate, leftRowCount, rightRowCount);
    }

    /** Returns {@code true} for the join types handled by the key-based estimation. */
    private static boolean isSupportedJoinType(JoinRelType joinType) {
        switch (joinType) {
            case INNER:
            case LEFT:
            case RIGHT:
            case FULL:
            case SEMI:
                return true;
            default:
                return false;
        }
    }

    /** Delegates the estimation to Calcite's generic join row-count formula. */
    @Nullable
    private static Double fallbackRowCount(RelMetadataQuery mq, Join rel) {
        return RelMdUtil.getJoinRowCount(mq, rel, rel.getCondition());
    }

    /**
     * Groups the equi-join key pairs by the pair of origin tables they come from. For each pair it records which
     * primary-key positions those keys cover.
     */
    private static Map<TablesPair, JoinContext> collectJoinContexts(
            RelMetadataQuery mq, Join rel, JoinInfo joinInfo) {
        Int2ObjectMap<KeyColumnOrigin> columnsFromLeft = resolveOrigins(mq, rel.getLeft(), joinInfo.leftKeys);
        Int2ObjectMap<KeyColumnOrigin> columnsFromRight = resolveOrigins(mq, rel.getRight(), joinInfo.rightKeys);

        Map<TablesPair, JoinContext> joinContexts = new HashMap<>();
        for (IntPair joinKeys : joinInfo.pairs()) {
            KeyColumnOrigin leftKey = columnsFromLeft.get(joinKeys.source);
            KeyColumnOrigin rightKey = columnsFromRight.get(joinKeys.target);
            if (leftKey == null || rightKey == null) {
                // Origin could not be traced to an Ignite table column on one side.
                continue;
            }

            TablesPair tablesPair = new TablesPair(
                    leftKey.origin.getOriginTable(), rightKey.origin.getOriginTable());

            joinContexts.computeIfAbsent(tablesPair, pair -> {
                IgniteTable leftTable = pair.left.unwrap(IgniteTable.class);
                IgniteTable rightTable = pair.right.unwrap(IgniteTable.class);

                // resolveOrigins only keeps columns whose origin table unwraps to IgniteTable.
                assert leftTable != null && rightTable != null;

                return new JoinContext(leftTable.keyColumns().size(), rightTable.keyColumns().size());
            }).countKeys(leftKey, rightKey);
        }

        return joinContexts;
    }

    /**
     * Returns the strongest relationship among the given contexts. It stops early once {@code PK_ON_PK} is found,
     * since nothing ranks higher.
     */
    private static JoiningRelationType strongestRelationType(Iterable<JoinContext> contexts) {
        JoiningRelationType strongest = JoiningRelationType.UNKNOWN;
        for (JoinContext context : contexts) {
            JoiningRelationType candidate = context.joinType();
            if (candidate.strength > strongest.strength) {
                strongest = candidate;
            }
            if (strongest == JoiningRelationType.PK_ON_PK) {
                break;
            }
        }
        return strongest;
    }

    /**
     * Clamps an estimate to the cardinality bounds implied by the join semantics.
     *
     * <p>A SEMI join emits each left row at most once. LEFT and RIGHT joins emit every row of their preserved side
     * at least once. A FULL join emits every row of both sides at least once.
     */
    private static double clampToJoinBounds(
            JoinRelType joinType, double estimate, double leftRowCount, double rightRowCount) {
        switch (joinType) {
            case SEMI:
                return Math.min(estimate, leftRowCount);
            case LEFT:
                return Math.max(estimate, leftRowCount);
            case RIGHT:
                return Math.max(estimate, rightRowCount);
            case FULL:
                return Math.max(estimate, Math.max(leftRowCount, rightRowCount));
            default:
                return estimate;
        }
    }

    /**
     * Maps each join key (a column index of {@code joinShoulder}) to its origin column. A key is included only when
     * it has a single column origin in an {@link IgniteTable}.
     */
    private static Int2ObjectMap<KeyColumnOrigin> resolveOrigins(
            RelMetadataQuery mq, RelNode joinShoulder, ImmutableIntList keys) {
        Int2ObjectMap<KeyColumnOrigin> origins = new Int2ObjectOpenHashMap<>();
        for (int key : keys) {
            if (origins.containsKey(key)) {
                continue;
            }

            RelColumnOrigin origin = mq.getColumnOrigin(joinShoulder, key);
            if (origin == null) {
                continue;
            }

            IgniteTable table = origin.getOriginTable().unwrap(IgniteTable.class);
            if (table == null) {
                continue;
            }

            int positionInKey = table.keyColumns().indexOf(origin.getOriginColumnOrdinal());
            origins.put(key, new KeyColumnOrigin(origin, positionInKey));
        }
        return origins;
    }

    /** Origin of a join key column together with its position in the origin table's primary key. */
    private static class KeyColumnOrigin {
        private final RelColumnOrigin origin;

        /** Index of the column within the table's primary key, or a negative value if it is not a key column. */
        private final int positionInKey;

        KeyColumnOrigin(RelColumnOrigin origin, int positionInKey) {
            this.origin = origin;
            this.positionInKey = positionInKey;
        }
    }

    /** Pair of origin tables (left, right), compared by reference identity. */
    private static class TablesPair {
        private final RelOptTable left;
        private final RelOptTable right;

        TablesPair(RelOptTable left, RelOptTable right) {
            this.left = left;
            this.right = right;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            TablesPair that = (TablesPair) o;
            return this.left == that.left && this.right == that.right;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.left, this.right);
        }
    }

    /** Tracks which primary-key positions of a pair of tables are still not covered by the join keys. */
    private static class JoinContext {
        /** Primary-key positions of the left table not yet matched by any join key. */
        private final BitSet leftKeys = new BitSet();

        /** Primary-key positions of the right table not yet matched by any join key. */
        private final BitSet rightKeys = new BitSet();

        /**
         * Positions not yet matched position-to-position on both sides; {@code null} when the two primary keys
         * have different sizes.
         */
        @Nullable
        private final BitSet commonKeys;

        JoinContext(int leftPkSize, int rightPkSize) {
            this.leftKeys.set(0, leftPkSize);
            this.rightKeys.set(0, rightPkSize);

            if (leftPkSize == rightPkSize) {
                this.commonKeys = new BitSet();
                this.commonKeys.set(0, leftPkSize);
            } else {
                this.commonKeys = null;
            }
        }

        /** Marks the key positions covered by one equi-join key pair. */
        void countKeys(KeyColumnOrigin left, KeyColumnOrigin right) {
            if (left.positionInKey >= 0) {
                this.leftKeys.clear(left.positionInKey);
            }
            if (right.positionInKey >= 0) {
                this.rightKeys.clear(right.positionInKey);
            }
            if (this.commonKeys != null && left.positionInKey >= 0 && left.positionInKey == right.positionInKey) {
                this.commonKeys.clear(left.positionInKey);
            }
        }

        /** Classifies the relationship based on which primary keys are fully covered. */
        JoiningRelationType joinType() {
            if (this.commonKeys != null && this.commonKeys.isEmpty()) {
                return JoiningRelationType.PK_ON_PK;
            }
            if (this.rightKeys.isEmpty()) {
                return JoiningRelationType.FK_ON_PK;
            }
            if (this.leftKeys.isEmpty()) {
                return JoiningRelationType.PK_ON_FK;
            }
            return JoiningRelationType.UNKNOWN;
        }
    }

    /** Relationship kinds between joined tables, ordered by increasing {@code strength}. */
    private enum JoiningRelationType {
        UNKNOWN(0),
        PK_ON_FK(1),
        FK_ON_PK(2),
        PK_ON_PK(3);

        private final int strength;

        JoiningRelationType(int strength) {
            this.strength = strength;
        }
    }
}

