package lsfusion.server.data.expr.join.query;

import lsfusion.base.BaseUtils;
import lsfusion.base.Result;
import lsfusion.base.col.SetFact;
import lsfusion.base.col.interfaces.immutable.*;
import lsfusion.base.col.interfaces.mutable.MSet;
import lsfusion.base.mutability.TwinImmutableObject;
import lsfusion.server.data.caches.AbstractOuterContext;
import lsfusion.server.data.caches.OuterContext;
import lsfusion.server.data.caches.hash.HashContext;
import lsfusion.server.data.expr.BaseExpr;
import lsfusion.server.data.expr.Expr;
import lsfusion.server.data.expr.join.where.GroupJoinsWhere;
import lsfusion.server.data.expr.join.where.KeyEqual;
import lsfusion.server.data.expr.join.where.WhereJoins;
import lsfusion.server.data.expr.key.KeyExpr;
import lsfusion.server.data.stat.*;
import lsfusion.server.data.translate.ExprTranslator;
import lsfusion.server.data.translate.MapTranslate;

/**
 * A set of {@link Node}s, each describing how a collection of grouping expressions (keyed by {@code K})
 * maps onto base expressions under one particular combination of key equalities and joins.
 * <p>
 * Statistics for the whole set are computed node by node and combined with {@code StatKeys.or}.
 * Hashing and twin comparison are based solely on the node set, which is fixed at construction.
 *
 * @param <K> type of the grouping-key expressions
 */
public class GroupExprWhereJoins<K extends Expr> extends AbstractOuterContext<GroupExprWhereJoins<K>> {

    /**
     * One alternative of a grouping: the grouping expressions mapped to base expressions, together with the
     * {@link KeyEqual} and {@link WhereJoins} they are evaluated under.
     *
     * @param <K> type of the grouping-key expressions
     */
    public static class Node<K extends Expr> extends AbstractOuterContext<Node<K>>  {
        /** Grouping expressions mapped to the base expressions they correspond to in this node. */
        public final ImMap<K, BaseExpr> mapExprs;
        /** Key equalities of this node; may be empty (checked via {@code isEmpty()} before use). */
        public final KeyEqual keyEqual;
        /** Joins of this node. */
        public final WhereJoins joins;

        public Node(ImMap<K, BaseExpr> mapExprs, KeyEqual keyEqual, WhereJoins joins) {
            this.mapExprs = mapExprs;
            this.keyEqual = keyEqual;
            this.joins = joins;
        }

        /**
         * Computes statistics over {@code allKeys} after pushing the statistics of this node's grouping
         * expressions back into its joins.
         * <ol>
         *   <li>Stats of the grouping base expressions are computed from {@code joins} (or from empty joins when
         *       {@code useWhere} is false) with {@code statKeys} pushed in.</li>
         *   <li>Those stats are pushed into {@code joins}; if there are key equalities, their joins are added and
         *       {@code keyStat} is wrapped by them. Stats for {@code allKeys} are then computed.</li>
         *   <li>If {@code orders} is non-empty, every key is covered by the grouping and the grouping is fully
         *       pushed, the cost is replaced by the ordered-access cost when one is available.</li>
         * </ol>
         *
         * @param keyStat   key statistics source
         * @param type      statistics type
         * @param statKeys  statistics pushed from the outside
         * @param allKeys   keys to compute statistics for
         * @param useWhere  whether this node's joins are used when computing the grouping statistics
         * @param orders    ordering; only its key set is used for the order cost
         * @return statistics over {@code allKeys}
         */
        public StatKeys<KeyExpr> getPartitionStatKeys(KeyStat keyStat, StatType type, StatKeys<KeyExpr> statKeys, ImSet<KeyExpr> allKeys, boolean useWhere, ImOrderMap<Expr, Boolean> orders) {
            // keyEqual is passed to getStatKeys below; presumably it is applied to keyStat there
            // (an explicit keyEqual.getKeyStat(keyStat) was deliberately left out here).
            ImSet<BaseExpr> group = mapExprs.values().toSet();
            StatKeys<BaseExpr> partitionStats = (useWhere ? joins : WhereJoins.EMPTY).pushStatKeys(statKeys).getStatKeys(group, keyStat, type, keyEqual);

            WhereJoins adjJoins = joins.pushStatKeys(partitionStats);

            if(!keyEqual.isEmpty()) {
                adjJoins = adjJoins.and(keyEqual.getWhereJoins());
                keyStat = keyEqual.getKeyStat(keyStat);
            }

            StatKeys<KeyExpr> result = adjJoins.getStatKeys(allKeys, keyStat, type);

            if (!orders.isEmpty()) {
                // assumed to yield the grouping expressions that are among allKeys
                ImSet<KeyExpr> partitionKeys = SetFact.filter(group, allKeys);
                if (partitionKeys.size() == allKeys.size() && WhereJoins.isPushedAll(BaseUtils.immutableCast(mapExprs), statKeys.getKeys())) {
                    Cost newCost = WhereJoins.getOrderCost(adjJoins, partitionKeys, result, keyStat, orders.keyOrderSet(), type, new Cost(result.getRows()), null);
                    if (newCost != null)
                        result = result.replaceCost(newCost);
                }
            }

            return result;
        }

        /**
         * Computes statistics over this node's grouping expressions, mapped back to the {@code K} keys.
         * <p>
         * Unless {@code statKeys} is the {@code StatKeys.NOPUSH()} sentinel, it is translated to the mapped base
         * expressions and pushed into the joins. Key equalities (if any) are added to the joins and wrap
         * {@code keyStat}. When pushing, {@code orders} is non-empty and the grouping is fully pushed, the cost is
         * replaced by the ordered-access cost when one is available.
         *
         * @param keyStat  key statistics source
         * @param type     statistics type
         * @param statKeys statistics pushed from the outside, or {@code StatKeys.NOPUSH()} to push nothing
         * @param orders   ordering used for the order cost
         * @return statistics keyed by {@code K}
         */
        public StatKeys<K> getStatKeys(KeyStat keyStat, StatType type, StatKeys<K> statKeys, ImOrderSet<Expr> orders) {
            WhereJoins adjJoins = joins;
            // reference comparison against the NOPUSH sentinel
            if(statKeys != StatKeys.<K>NOPUSH()) {
                Result<ImRevMap<K, BaseExpr>> revMap = new Result<>();
                StatKeys<K> revStatKeys = statKeys.toRevMap(mapExprs.filterIncl(statKeys.getKeys()), revMap);
                adjJoins = adjJoins.pushStatKeys(revStatKeys.mapBack(revMap.result.reverse()));
            }

            ImSet<BaseExpr> group = mapExprs.values().toSet();

            if(!keyEqual.isEmpty()) {
                adjJoins = adjJoins.and(keyEqual.getWhereJoins());
                keyStat = keyEqual.getKeyStat(keyStat);
            }

            StatKeys<BaseExpr> result = adjJoins.getStatKeys(group, keyStat, type);

            if(statKeys != StatKeys.<K>NOPUSH() && !orders.isEmpty() && WhereJoins.isPushedAll(mapExprs, statKeys.getKeys())) {
                Cost newCost = WhereJoins.getOrderCost(adjJoins, group, result, keyStat, orders, type, new Cost(result.getRows()), null);
                if (newCost != null)
                    result = result.replaceCost(newCost);
            }

            return result.mapBack(mapExprs);
        }

        /** Depends on the mapped keys, the mapped base expressions, the key equalities and the joins. */
        @Override
        protected ImSet<OuterContext> calculateOuterDepends() {
            return SetFact.mergeSet(mapExprs.keys(), BaseUtils.<ImSet<OuterContext>>immutableCast(mapExprs.values().toSet())).merge(keyEqual).merge(joins);
        }

        @Override
        protected Node<K> translate(MapTranslate translator) {
            return new Node<>(translator.translateMap(mapExprs), keyEqual.translateOuter(translator), joins.translateOuter(translator));
        }

        @Override
        public int hash(HashContext hash) {
            return 31 * (31 * AbstractOuterContext.hashMapOuter(mapExprs, hash) + keyEqual.hashOuter(hash)) + joins.hashOuter(hash);
        }

        @Override
        protected boolean calcTwins(TwinImmutableObject o) {
            Node<?> other = (Node<?>) o;
            return mapExprs.equals(other.mapExprs) && keyEqual.equals(other.keyEqual) && joins.equals(other.joins);
        }
    }

    /** Alternatives of this grouping; fixed at construction. */
    private final ImSet<Node<K>> nodes;

    public GroupExprWhereJoins(ImSet<Node<K>> nodes) {
        this.nodes = nodes;
    }

    /**
     * Adds all nodes of this instance to {@code mResult}.
     *
     * @param mResult mutable set receiving the nodes
     */
    public void addAll(MSet<Node<K>> mResult) {
        mResult.addAll(nodes);
    }

    /**
     * Combines (with {@code StatKeys.or}) the partition statistics of every node over {@code allKeys}.
     * See {@link Node#getPartitionStatKeys} for the per-node computation.
     */
    public StatKeys<KeyExpr> getPartitionStatKeys(final KeyStat keyStat, final StatType type, final StatKeys<KeyExpr> statKeys, final boolean useWhere, final ImSet<KeyExpr> allKeys, ImOrderMap<Expr, Boolean> orders) {
        return StatKeys.or(nodes, value -> value.getPartitionStatKeys(keyStat, type, statKeys, allKeys, useWhere, orders), allKeys);
    }

    /**
     * Combines (with {@code StatKeys.or}) the per-node statistics over the grouping keys.
     * See {@link Node#getStatKeys} for the per-node computation.
     */
    public StatKeys<K> getStatKeys(final KeyStat keyStat, final StatType type, final StatKeys<K> statKeys, ImSet<K> allKeys, ImOrderSet<Expr> orders) {
        return StatKeys.or(nodes, value -> value.getStatKeys(keyStat, type, statKeys, orders), allKeys);
    }

    /**
     * Builds the grouping alternatives for {@code mapExprs} from the given join/where alternatives.
     * <p>
     * For each alternative without key equalities, a node with {@code mapExprs} as-is is created. Otherwise the
     * expressions are translated through the key equalities. If every translated expression is still a base
     * expression, a node with the translated mapping is created. If not, the alternative's full where is asked
     * for its own grouping alternatives of the translated expressions, and all of those nodes are added.
     * <p>
     * Note: the incoming {@code GroupJoinsWhere} values may come (and in fact always come) without their Where part.
     *
     * @param whereJoins      join/where alternatives
     * @param mapExprs        grouping expressions mapped to base expressions
     * @param statType        statistics type, forwarded to the recursive case
     * @param forcePackReduce forwarded to the recursive case
     * @return the collected alternatives
     */
    public static <K extends Expr> GroupExprWhereJoins<K> create(ImCol<GroupJoinsWhere> whereJoins, final ImMap<K, BaseExpr> mapExprs, StatType statType, boolean forcePackReduce) {
        MSet<Node<K>> mResult = SetFact.mSet();
        for(int i=0,size=whereJoins.size();i<size;i++) {
            GroupJoinsWhere joinsWhere = whereJoins.get(i);
            if(joinsWhere.keyEqual.isEmpty())
                mResult.add(new Node<>(mapExprs, joinsWhere.keyEqual, joinsWhere.joins));
            else {
                ExprTranslator translator = joinsWhere.keyEqual.getTranslator();
                ImMap<K, Expr> transMapExprs = translator.translate(mapExprs);
                // null when some translated expression is no longer a BaseExpr
                ImMap<K, BaseExpr> transMapBaseExprs = BaseExpr.onlyBaseExprs(transMapExprs);
                if(transMapBaseExprs != null)
                    mResult.add(new Node<>(transMapBaseExprs, joinsWhere.keyEqual, joinsWhere.joins));
                else
                    joinsWhere.getFullWhere().getGroupExprWhereJoins(transMapExprs, statType, forcePackReduce).addAll(mResult);
            }
        }
        return new GroupExprWhereJoins<>(mResult.immutable());
    }

    /** Shared instance without nodes; it holds no {@code K}-typed data, so it can serve every {@code K}. */
    private static final GroupExprWhereJoins<?> EMPTY = new GroupExprWhereJoins<Expr>(SetFact.EMPTY());

    /**
     * Returns the shared instance without nodes.
     */
    @SuppressWarnings("unchecked") // safe: the shared instance has an empty node set
    public static <K extends Expr> GroupExprWhereJoins<K> EMPTY() {
        return (GroupExprWhereJoins<K>) EMPTY;
    }

    /**
     * Returns a new instance whose nodes are the union of this instance's and {@code joins}' nodes.
     *
     * @param joins instance to merge with
     * @return the merged instance
     */
    public GroupExprWhereJoins<K> merge(GroupExprWhereJoins<K> joins) {
        return new GroupExprWhereJoins<>(nodes.merge(joins.nodes));
    }

    @Override
    protected ImSet<OuterContext> calculateOuterDepends() {
        return BaseUtils.immutableCast(nodes);
    }

    @Override
    protected GroupExprWhereJoins<K> translate(MapTranslate translator) {
        return new GroupExprWhereJoins<>(translator.translateSet(nodes));
    }

    @Override
    public int hash(HashContext hash) {
        return AbstractOuterContext.hashOuter(nodes, hash);
    }

    @Override
    protected boolean calcTwins(TwinImmutableObject o) {
        return nodes.equals(((GroupExprWhereJoins<?>) o).nodes);
    }
}
